MLIR 23.0.0git
ArithOps.cpp
Go to the documentation of this file.
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37/// Default rounding mode according to default LLVM floating-point environment.
38static constexpr llvm::RoundingMode kDefaultRoundingMode =
39 llvm::RoundingMode::NearestTiesToEven;
40
41//===----------------------------------------------------------------------===//
42// Pattern helpers
43//===----------------------------------------------------------------------===//
44
45static IntegerAttr
48 function_ref<APInt(const APInt &, const APInt &)> binFn) {
49 const APInt &lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
50 const APInt &rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
51 APInt value = binFn(lhsVal, rhsVal);
52 return IntegerAttr::get(res.getType(), value);
53}
54
55static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
57 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
58}
59
60static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
62 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
63}
64
65static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
67 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
68}
69
70// Merge overflow flags from 2 ops, selecting the most conservative combination.
71static IntegerOverflowFlagsAttr
72mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
73 IntegerOverflowFlagsAttr val2) {
74 return IntegerOverflowFlagsAttr::get(val1.getContext(),
75 val1.getValue() & val2.getValue());
76}
77
78/// Invert an integer comparison predicate.
79arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
80 switch (pred) {
81 case arith::CmpIPredicate::eq:
82 return arith::CmpIPredicate::ne;
83 case arith::CmpIPredicate::ne:
84 return arith::CmpIPredicate::eq;
85 case arith::CmpIPredicate::slt:
86 return arith::CmpIPredicate::sge;
87 case arith::CmpIPredicate::sle:
88 return arith::CmpIPredicate::sgt;
89 case arith::CmpIPredicate::sgt:
90 return arith::CmpIPredicate::sle;
91 case arith::CmpIPredicate::sge:
92 return arith::CmpIPredicate::slt;
93 case arith::CmpIPredicate::ult:
94 return arith::CmpIPredicate::uge;
95 case arith::CmpIPredicate::ule:
96 return arith::CmpIPredicate::ugt;
97 case arith::CmpIPredicate::ugt:
98 return arith::CmpIPredicate::ule;
99 case arith::CmpIPredicate::uge:
100 return arith::CmpIPredicate::ult;
101 }
102 llvm_unreachable("unknown cmpi predicate kind");
103}
104
105/// Equivalent to
106/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
107///
108/// Not possible to implement as chain of calls as this would introduce a
109/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
110/// on the LLVM dialect and on translation to LLVM.
111static llvm::RoundingMode
112convertArithRoundingModeToLLVMIR(std::optional<RoundingMode> roundingMode) {
113 if (!roundingMode)
115 switch (*roundingMode) {
116 case RoundingMode::downward:
117 return llvm::RoundingMode::TowardNegative;
118 case RoundingMode::to_nearest_away:
119 return llvm::RoundingMode::NearestTiesToAway;
120 case RoundingMode::to_nearest_even:
121 return llvm::RoundingMode::NearestTiesToEven;
122 case RoundingMode::toward_zero:
123 return llvm::RoundingMode::TowardZero;
124 case RoundingMode::upward:
125 return llvm::RoundingMode::TowardPositive;
126 }
127 llvm_unreachable("Unhandled rounding mode");
128}
129
130static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
131 return arith::CmpIPredicateAttr::get(pred.getContext(),
132 invertPredicate(pred.getValue()));
133}
134
136 Type elemTy = getElementTypeOrSelf(type);
137 if (elemTy.isIntOrFloat())
138 return elemTy.getIntOrFloatBitWidth();
139
140 return -1;
141}
142
144 return getScalarOrElementWidth(value.getType());
145}
146
147static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
148 APInt value;
149 if (matchPattern(attr, m_ConstantInt(&value)))
150 return value;
151
152 return failure();
153}
154
155static Attribute getBoolAttribute(Type type, bool value) {
156 auto boolAttr = BoolAttr::get(type.getContext(), value);
157 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
158 if (!shapedType)
159 return boolAttr;
160 // DenseElementsAttr requires a static shape.
161 if (!shapedType.hasStaticShape())
162 return {};
163 return DenseElementsAttr::get(shapedType, boolAttr);
164}
165
166//===----------------------------------------------------------------------===//
167// TableGen'd canonicalization patterns
168//===----------------------------------------------------------------------===//
169
170namespace {
171#include "ArithCanonicalization.inc"
172} // namespace
173
174//===----------------------------------------------------------------------===//
175// Common helpers
176//===----------------------------------------------------------------------===//
177
178/// Return the type of the same shape (scalar, vector or tensor) containing i1.
180 auto i1Type = IntegerType::get(type.getContext(), 1);
181 if (auto shapedType = dyn_cast<ShapedType>(type))
182 return shapedType.cloneWith(std::nullopt, i1Type);
183 if (llvm::isa<UnrankedTensorType>(type))
184 return UnrankedTensorType::get(i1Type);
185 return i1Type;
186}
187
188//===----------------------------------------------------------------------===//
189// ConstantOp
190//===----------------------------------------------------------------------===//
191
192void arith::ConstantOp::getAsmResultNames(
193 function_ref<void(Value, StringRef)> setNameFn) {
194 auto type = getType();
195 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
196 auto intType = dyn_cast<IntegerType>(type);
197
198 // Sugar i1 constants with 'true' and 'false'.
199 if (intType && intType.getWidth() == 1)
200 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
201
202 // Otherwise, build a complex name with the value and type.
203 SmallString<32> specialNameBuffer;
204 llvm::raw_svector_ostream specialName(specialNameBuffer);
205 specialName << 'c' << intCst.getValue();
206 if (intType)
207 specialName << '_' << type;
208 setNameFn(getResult(), specialName.str());
209 } else {
210 setNameFn(getResult(), "cst");
211 }
212}
213
214/// TODO: disallow arith.constant to return anything other than signless integer
215/// or float like.
216LogicalResult arith::ConstantOp::verify() {
217 auto type = getType();
218 // Integer values must be signless.
219 if (llvm::isa<IntegerType>(type) &&
220 !llvm::cast<IntegerType>(type).isSignless())
221 return emitOpError("integer return type must be signless");
222 // Any float or elements attribute are acceptable.
223 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
224 return emitOpError(
225 "value must be an integer, float, or elements attribute");
226 }
227
228 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
229 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
230 // However, this would most likely require updating the lowerings to LLVM.
231 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
232 return emitOpError(
233 "initializing scalable vectors with elements attribute is not supported"
234 " unless it's a vector splat");
235 return success();
236}
237
238bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
239 // The value's type must be the same as the provided type.
240 auto typedAttr = dyn_cast<TypedAttr>(value);
241 if (!typedAttr || typedAttr.getType() != type)
242 return false;
243 // Integer values must be signless.
244 if (auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(type))) {
245 if (!intType.isSignless())
246 return false;
247 }
248 // Integer, float, and element attributes are buildable.
249 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
250}
251
252ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
253 Type type, Location loc) {
254 if (isBuildableWith(value, type))
255 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
256 return nullptr;
257}
258
259OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
260
262 int64_t value, unsigned width) {
263 auto type = builder.getIntegerType(width);
264 arith::ConstantOp::build(builder, result, type,
265 builder.getIntegerAttr(type, value));
266}
267
269 Location location,
271 unsigned width) {
272 mlir::OperationState state(location, getOperationName());
273 build(builder, state, value, width);
274 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
275 assert(result && "builder didn't return the right type");
276 return result;
277}
278
281 unsigned width) {
282 return create(builder, builder.getLoc(), value, width);
283}
284
286 Type type, int64_t value) {
287 arith::ConstantOp::build(builder, result, type,
288 builder.getIntegerAttr(type, value));
289}
290
292 Location location, Type type,
293 int64_t value) {
294 mlir::OperationState state(location, getOperationName());
295 build(builder, state, type, value);
296 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
297 assert(result && "builder didn't return the right type");
298 return result;
299}
300
302 Type type, int64_t value) {
303 return create(builder, builder.getLoc(), type, value);
304}
305
307 Type type, const APInt &value) {
308 arith::ConstantOp::build(builder, result, type,
309 builder.getIntegerAttr(type, value));
310}
311
313 Location location, Type type,
314 const APInt &value) {
315 mlir::OperationState state(location, getOperationName());
316 build(builder, state, type, value);
317 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
318 assert(result && "builder didn't return the right type");
319 return result;
320}
321
323 Type type,
324 const APInt &value) {
325 return create(builder, builder.getLoc(), type, value);
326}
327
329 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
330 return constOp.getType().isSignlessInteger();
331 return false;
332}
333
335 FloatType type, const APFloat &value) {
336 arith::ConstantOp::build(builder, result, type,
337 builder.getFloatAttr(type, value));
338}
339
341 Location location,
342 FloatType type,
343 const APFloat &value) {
344 mlir::OperationState state(location, getOperationName());
345 build(builder, state, type, value);
346 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
347 assert(result && "builder didn't return the right type");
348 return result;
349}
350
353 const APFloat &value) {
354 return create(builder, builder.getLoc(), type, value);
355}
356
358 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
359 return llvm::isa<FloatType>(constOp.getType());
360 return false;
361}
362
364 int64_t value) {
365 arith::ConstantOp::build(builder, result, builder.getIndexType(),
366 builder.getIndexAttr(value));
367}
368
370 Location location,
371 int64_t value) {
372 mlir::OperationState state(location, getOperationName());
373 build(builder, state, value);
374 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
375 assert(result && "builder didn't return the right type");
376 return result;
377}
378
381 return create(builder, builder.getLoc(), value);
382}
383
385 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
386 return constOp.getType().isIndex();
387 return false;
388}
389
391 Type type) {
392 // TODO: Incorporate this check to `FloatAttr::get*`.
393 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
394 "type doesn't have a zero representation");
395 TypedAttr zeroAttr = builder.getZeroAttr(type);
396 assert(zeroAttr && "unsupported type for zero attribute");
397 return arith::ConstantOp::create(builder, loc, zeroAttr);
398}
399
400//===----------------------------------------------------------------------===//
401// AddIOp
402//===----------------------------------------------------------------------===//
403
404OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
405 // addi(x, 0) -> x
406 if (matchPattern(adaptor.getRhs(), m_Zero()))
407 return getLhs();
408
409 // addi(subi(a, b), b) -> a
410 if (auto sub = getLhs().getDefiningOp<SubIOp>())
411 if (getRhs() == sub.getRhs())
412 return sub.getLhs();
413
414 // addi(b, subi(a, b)) -> a
415 if (auto sub = getRhs().getDefiningOp<SubIOp>())
416 if (getLhs() == sub.getRhs())
417 return sub.getLhs();
418
420 adaptor.getOperands(),
421 [](APInt a, const APInt &b) { return std::move(a) + b; });
422}
423
424void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
425 MLIRContext *context) {
426 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
427 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
428}
429
430//===----------------------------------------------------------------------===//
431// AddUIExtendedOp
432//===----------------------------------------------------------------------===//
433
434std::optional<SmallVector<int64_t, 4>>
435arith::AddUIExtendedOp::getShapeForUnroll() {
436 if (auto vt = dyn_cast<VectorType>(getType(0)))
437 return llvm::to_vector<4>(vt.getShape());
438 return std::nullopt;
439}
440
441// Returns the overflow bit, assuming that `sum` is the result of unsigned
442// addition of `operand` and another number.
443static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
444 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
445}
446
447LogicalResult
448arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
449 SmallVectorImpl<OpFoldResult> &results) {
450 Type overflowTy = getOverflow().getType();
451 // addui_extended(x, 0) -> x, false
452 if (matchPattern(getRhs(), m_Zero())) {
453 Builder builder(getContext());
454 auto falseValue = builder.getZeroAttr(overflowTy);
455
456 results.push_back(getLhs());
457 results.push_back(falseValue);
458 return success();
459 }
460
461 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
462 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
463 // operands. If that succeeds, calculate the overflow bit based on the sum
464 // and the first (constant) operand, `lhs`.
465 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
466 adaptor.getOperands(),
467 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
468 // If any operand is poison, propagate poison to both results.
469 if (matchPattern(sumAttr, ub::m_Poison())) {
470 results.push_back(sumAttr);
471 results.push_back(sumAttr);
472 return success();
473 }
474 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
475 ArrayRef({sumAttr, adaptor.getLhs()}),
476 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
478 if (!overflowAttr)
479 return failure();
480
481 results.push_back(sumAttr);
482 results.push_back(overflowAttr);
483 return success();
484 }
485
486 return failure();
487}
488
489void arith::AddUIExtendedOp::getCanonicalizationPatterns(
490 RewritePatternSet &patterns, MLIRContext *context) {
491 patterns.add<AddUIExtendedToAddI>(context);
492}
493
494//===----------------------------------------------------------------------===//
495// SubIOp
496//===----------------------------------------------------------------------===//
497
498OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
499 // subi(x,x) -> 0
500 if (getOperand(0) == getOperand(1)) {
501 auto shapedType = dyn_cast<ShapedType>(getType());
502 // We can't generate a constant with a dynamic shaped tensor.
503 if (!shapedType || shapedType.hasStaticShape())
504 return Builder(getContext()).getZeroAttr(getType());
505 }
506 // subi(x,0) -> x
507 if (matchPattern(adaptor.getRhs(), m_Zero()))
508 return getLhs();
509
510 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
511 // subi(addi(a, b), b) -> a
512 if (getRhs() == add.getRhs())
513 return add.getLhs();
514 // subi(addi(a, b), a) -> b
515 if (getRhs() == add.getLhs())
516 return add.getRhs();
517 }
518
520 adaptor.getOperands(),
521 [](APInt a, const APInt &b) { return std::move(a) - b; });
522}
523
524void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
525 MLIRContext *context) {
526 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
527 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
528 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
529}
530
531//===----------------------------------------------------------------------===//
532// MulIOp
533//===----------------------------------------------------------------------===//
534
535OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
536 // muli(x, 0) -> 0
537 if (matchPattern(adaptor.getRhs(), m_Zero()))
538 return getRhs();
539 // muli(x, 1) -> x
540 if (matchPattern(adaptor.getRhs(), m_One()))
541 return getLhs();
542 // TODO: Handle the overflow case.
543
544 // default folder
546 adaptor.getOperands(),
547 [](const APInt &a, const APInt &b) { return a * b; });
548}
549
550void arith::MulIOp::getAsmResultNames(
551 function_ref<void(Value, StringRef)> setNameFn) {
552 if (!isa<IndexType>(getType()))
553 return;
554
555 // Match vector.vscale by name to avoid depending on the vector dialect (which
556 // is a circular dependency).
557 auto isVscale = [](Operation *op) {
558 return op && op->getName().getStringRef() == "vector.vscale";
559 };
560
561 IntegerAttr baseValue;
562 auto isVscaleExpr = [&](Value a, Value b) {
563 return matchPattern(a, m_Constant(&baseValue)) &&
564 isVscale(b.getDefiningOp());
565 };
566
567 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
568 return;
569
570 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
571 SmallString<32> specialNameBuffer;
572 llvm::raw_svector_ostream specialName(specialNameBuffer);
573 specialName << 'c' << baseValue.getInt() << "_vscale";
574 setNameFn(getResult(), specialName.str());
575}
576
577void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
578 MLIRContext *context) {
579 patterns.add<MulIMulIConstant>(context);
580}
581
582//===----------------------------------------------------------------------===//
583// MulSIExtendedOp
584//===----------------------------------------------------------------------===//
585
586std::optional<SmallVector<int64_t, 4>>
587arith::MulSIExtendedOp::getShapeForUnroll() {
588 if (auto vt = dyn_cast<VectorType>(getType(0)))
589 return llvm::to_vector<4>(vt.getShape());
590 return std::nullopt;
591}
592
593LogicalResult
594arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
595 SmallVectorImpl<OpFoldResult> &results) {
596 // mulsi_extended(x, 0) -> 0, 0
597 if (matchPattern(adaptor.getRhs(), m_Zero())) {
598 Attribute zero = adaptor.getRhs();
599 results.push_back(zero);
600 results.push_back(zero);
601 return success();
602 }
603
604 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
605 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
606 adaptor.getOperands(),
607 [](const APInt &a, const APInt &b) { return a * b; })) {
608 // Invoke the constant fold helper again to calculate the 'high' result.
609 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
610 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
611 return llvm::APIntOps::mulhs(a, b);
612 });
613 assert(highAttr && "Unexpected constant-folding failure");
614
615 results.push_back(lowAttr);
616 results.push_back(highAttr);
617 return success();
618 }
619
620 return failure();
621}
622
623void arith::MulSIExtendedOp::getCanonicalizationPatterns(
624 RewritePatternSet &patterns, MLIRContext *context) {
625 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
626}
627
628//===----------------------------------------------------------------------===//
629// MulUIExtendedOp
630//===----------------------------------------------------------------------===//
631
632std::optional<SmallVector<int64_t, 4>>
633arith::MulUIExtendedOp::getShapeForUnroll() {
634 if (auto vt = dyn_cast<VectorType>(getType(0)))
635 return llvm::to_vector<4>(vt.getShape());
636 return std::nullopt;
637}
638
639LogicalResult
640arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
641 SmallVectorImpl<OpFoldResult> &results) {
642 // mului_extended(x, 0) -> 0, 0
643 if (matchPattern(adaptor.getRhs(), m_Zero())) {
644 Attribute zero = adaptor.getRhs();
645 results.push_back(zero);
646 results.push_back(zero);
647 return success();
648 }
649
650 // mului_extended(x, 1) -> x, 0
651 if (matchPattern(adaptor.getRhs(), m_One())) {
652 Builder builder(getContext());
653 Attribute zero = builder.getZeroAttr(getLhs().getType());
654 results.push_back(getLhs());
655 results.push_back(zero);
656 return success();
657 }
658
659 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
660 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
661 adaptor.getOperands(),
662 [](const APInt &a, const APInt &b) { return a * b; })) {
663 // Invoke the constant fold helper again to calculate the 'high' result.
664 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
665 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
666 return llvm::APIntOps::mulhu(a, b);
667 });
668 assert(highAttr && "Unexpected constant-folding failure");
669
670 results.push_back(lowAttr);
671 results.push_back(highAttr);
672 return success();
673 }
674
675 return failure();
676}
677
678void arith::MulUIExtendedOp::getCanonicalizationPatterns(
679 RewritePatternSet &patterns, MLIRContext *context) {
680 patterns.add<MulUIExtendedToMulI>(context);
681}
682
683//===----------------------------------------------------------------------===//
684// DivUIOp
685//===----------------------------------------------------------------------===//
686
687/// Fold `(a * b) / b -> a`
689 arith::IntegerOverflowFlags ovfFlags) {
690 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
691 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
692 return {};
693
694 if (mul.getLhs() == rhs)
695 return mul.getRhs();
696
697 if (mul.getRhs() == rhs)
698 return mul.getLhs();
699
700 return {};
701}
702
703OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
704 // divui (x, 1) -> x.
705 if (matchPattern(adaptor.getRhs(), m_One()))
706 return getLhs();
707
708 // (a * b) / b -> a
709 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
710 return val;
711
712 // Don't fold if it would require a division by zero.
713 bool div0 = false;
714 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
715 [&](APInt a, const APInt &b) {
716 if (div0 || !b) {
717 div0 = true;
718 return a;
719 }
720 return a.udiv(b);
721 });
722
723 return div0 ? Attribute() : result;
724}
725
726/// Returns whether an unsigned division by `divisor` is speculatable.
728 // X / 0 => UB
729 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
731
733}
734
735Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
736 return getDivUISpeculatability(getRhs());
737}
738
739//===----------------------------------------------------------------------===//
740// DivSIOp
741//===----------------------------------------------------------------------===//
742
743OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
744 // divsi (x, 1) -> x.
745 if (matchPattern(adaptor.getRhs(), m_One()))
746 return getLhs();
747
748 // (a * b) / b -> a
749 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
750 return val;
751
752 // Don't fold if it would overflow or if it requires a division by zero.
753 bool overflowOrDiv0 = false;
755 adaptor.getOperands(), [&](APInt a, const APInt &b) {
756 if (overflowOrDiv0 || !b) {
757 overflowOrDiv0 = true;
758 return a;
759 }
760 return a.sdiv_ov(b, overflowOrDiv0);
761 });
762
763 return overflowOrDiv0 ? Attribute() : result;
764}
765
766/// Returns whether a signed division by `divisor` is speculatable. This
767/// function conservatively assumes that all signed division by -1 are not
768/// speculatable.
770 // X / 0 => UB
771 // INT_MIN / -1 => UB
772 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
775
777}
778
779Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
780 return getDivSISpeculatability(getRhs());
781}
782
783//===----------------------------------------------------------------------===//
784// Ceil and floor division folding helpers
785//===----------------------------------------------------------------------===//
786
787static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
788 bool &overflow) {
789 // Returns (a-1)/b + 1
790 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
791 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
792 return val.sadd_ov(one, overflow);
793}
794
795//===----------------------------------------------------------------------===//
796// CeilDivUIOp
797//===----------------------------------------------------------------------===//
798
799OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
800 // ceildivui (x, 1) -> x.
801 if (matchPattern(adaptor.getRhs(), m_One()))
802 return getLhs();
803
804 bool overflowOrDiv0 = false;
806 adaptor.getOperands(), [&](APInt a, const APInt &b) {
807 if (overflowOrDiv0 || !b) {
808 overflowOrDiv0 = true;
809 return a;
810 }
811 APInt quotient = a.udiv(b);
812 if (!a.urem(b))
813 return quotient;
814 APInt one(a.getBitWidth(), 1, true);
815 return quotient.uadd_ov(one, overflowOrDiv0);
816 });
817
818 return overflowOrDiv0 ? Attribute() : result;
819}
820
821Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
822 return getDivUISpeculatability(getRhs());
823}
824
825//===----------------------------------------------------------------------===//
826// CeilDivSIOp
827//===----------------------------------------------------------------------===//
828
829OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
830 // ceildivsi (x, 1) -> x.
831 if (matchPattern(adaptor.getRhs(), m_One()))
832 return getLhs();
833
834 // Don't fold if it would overflow or if it requires a division by zero.
835 // TODO: This hook won't fold operations where a = MININT, because
836 // negating MININT overflows. This can be improved.
837 bool overflowOrDiv0 = false;
839 adaptor.getOperands(), [&](APInt a, const APInt &b) {
840 if (overflowOrDiv0 || !b) {
841 overflowOrDiv0 = true;
842 return a;
843 }
844 if (!a)
845 return a;
846 // After this point we know that neither a or b are zero.
847 unsigned bits = a.getBitWidth();
848 APInt zero = APInt::getZero(bits);
849 bool aGtZero = a.sgt(zero);
850 bool bGtZero = b.sgt(zero);
851 if (aGtZero && bGtZero) {
852 // Both positive, return ceil(a, b).
853 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
854 }
855
856 // No folding happens if any of the intermediate arithmetic operations
857 // overflows.
858 bool overflowNegA = false;
859 bool overflowNegB = false;
860 bool overflowDiv = false;
861 bool overflowNegRes = false;
862 if (!aGtZero && !bGtZero) {
863 // Both negative, return ceil(-a, -b).
864 APInt posA = zero.ssub_ov(a, overflowNegA);
865 APInt posB = zero.ssub_ov(b, overflowNegB);
866 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
867 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
868 return res;
869 }
870 if (!aGtZero && bGtZero) {
871 // A is negative, b is positive, return - ( -a / b).
872 APInt posA = zero.ssub_ov(a, overflowNegA);
873 APInt div = posA.sdiv_ov(b, overflowDiv);
874 APInt res = zero.ssub_ov(div, overflowNegRes);
875 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
876 return res;
877 }
878 // A is positive, b is negative, return - (a / -b).
879 APInt posB = zero.ssub_ov(b, overflowNegB);
880 APInt div = a.sdiv_ov(posB, overflowDiv);
881 APInt res = zero.ssub_ov(div, overflowNegRes);
882
883 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
884 return res;
885 });
886
887 return overflowOrDiv0 ? Attribute() : result;
888}
889
890Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
891 return getDivSISpeculatability(getRhs());
892}
893
894//===----------------------------------------------------------------------===//
895// FloorDivSIOp
896//===----------------------------------------------------------------------===//
897
898OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
899 // floordivsi (x, 1) -> x.
900 if (matchPattern(adaptor.getRhs(), m_One()))
901 return getLhs();
902
903 // Don't fold if it would overflow or if it requires a division by zero.
904 bool overflowOrDiv = false;
906 adaptor.getOperands(), [&](APInt a, const APInt &b) {
907 if (b.isZero()) {
908 overflowOrDiv = true;
909 return a;
910 }
911 return a.sfloordiv_ov(b, overflowOrDiv);
912 });
913
914 return overflowOrDiv ? Attribute() : result;
915}
916
917//===----------------------------------------------------------------------===//
918// RemUIOp
919//===----------------------------------------------------------------------===//
920
921OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
922 // remui (x, 1) -> 0.
923 if (matchPattern(adaptor.getRhs(), m_One()))
924 return Builder(getContext()).getZeroAttr(getType());
925
926 // Don't fold if it would require a division by zero.
927 bool div0 = false;
928 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
929 [&](APInt a, const APInt &b) {
930 if (div0 || b.isZero()) {
931 div0 = true;
932 return a;
933 }
934 return a.urem(b);
935 });
936
937 return div0 ? Attribute() : result;
938}
939
940Speculation::Speculatability arith::RemUIOp::getSpeculatability() {
941 return getDivUISpeculatability(getRhs());
942}
943
944//===----------------------------------------------------------------------===//
945// RemSIOp
946//===----------------------------------------------------------------------===//
947
948OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
949 // remsi (x, 1) -> 0.
950 if (matchPattern(adaptor.getRhs(), m_One()))
951 return Builder(getContext()).getZeroAttr(getType());
952
953 // Don't fold if it would require a division by zero.
954 bool div0 = false;
955 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
956 [&](APInt a, const APInt &b) {
957 if (div0 || b.isZero()) {
958 div0 = true;
959 return a;
960 }
961 return a.srem(b);
962 });
963
964 return div0 ? Attribute() : result;
965}
966
967Speculation::Speculatability arith::RemSIOp::getSpeculatability() {
968 // X % 0 => UB
969 // X % -1 is well-defined (always 0), unlike X / -1 which can overflow.
970 if (matchPattern(getRhs(), m_IntRangeWithoutZeroS()))
972
974}
975
976//===----------------------------------------------------------------------===//
977// AndIOp
978//===----------------------------------------------------------------------===//
979
980/// Fold `and(a, and(a, b))` to `and(a, b)`
981static Value foldAndIofAndI(arith::AndIOp op) {
982 for (bool reversePrev : {false, true}) {
983 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
984 .getDefiningOp<arith::AndIOp>();
985 if (!prev)
986 continue;
987
988 Value other = (reversePrev ? op.getLhs() : op.getRhs());
989 if (other != prev.getLhs() && other != prev.getRhs())
990 continue;
991
992 return prev.getResult();
993 }
994 return {};
995}
996
997OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
998 /// and(x, 0) -> 0
999 if (matchPattern(adaptor.getRhs(), m_Zero()))
1000 return getRhs();
1001 /// and(x, allOnes) -> x
1002 APInt intValue;
1003 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
1004 intValue.isAllOnes())
1005 return getLhs();
1006 /// and(x, not(x)) -> 0
1007 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1008 m_ConstantInt(&intValue))) &&
1009 intValue.isAllOnes())
1010 return Builder(getContext()).getZeroAttr(getType());
1011 /// and(not(x), x) -> 0
1012 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1013 m_ConstantInt(&intValue))) &&
1014 intValue.isAllOnes())
1015 return Builder(getContext()).getZeroAttr(getType());
1016
1017 /// and(a, and(a, b)) -> and(a, b)
1018 if (Value result = foldAndIofAndI(*this))
1019 return result;
1020
1022 adaptor.getOperands(),
1023 [](APInt a, const APInt &b) { return std::move(a) & b; });
1024}
1025
1026//===----------------------------------------------------------------------===//
1027// OrIOp
1028//===----------------------------------------------------------------------===//
1029
1030OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1031 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1032 /// or(x, 0) -> x
1033 if (rhsVal.isZero())
1034 return getLhs();
1035 /// or(x, <all ones>) -> <all ones>
1036 if (rhsVal.isAllOnes())
1037 return adaptor.getRhs();
1038 }
1039
1040 APInt intValue;
1041 /// or(x, xor(x, 1)) -> 1
1042 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1043 m_ConstantInt(&intValue))) &&
1044 intValue.isAllOnes())
1045 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1046 /// or(xor(x, 1), x) -> 1
1047 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1048 m_ConstantInt(&intValue))) &&
1049 intValue.isAllOnes())
1050 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1051
1053 adaptor.getOperands(),
1054 [](APInt a, const APInt &b) { return std::move(a) | b; });
1055}
1056
1057//===----------------------------------------------------------------------===//
1058// XOrIOp
1059//===----------------------------------------------------------------------===//
1060
1061OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1062 /// xor(x, 0) -> x
1063 if (matchPattern(adaptor.getRhs(), m_Zero()))
1064 return getLhs();
1065 /// xor(x, x) -> 0
1066 if (getLhs() == getRhs())
1067 return Builder(getContext()).getZeroAttr(getType());
1068 /// xor(xor(x, a), a) -> x
1069 /// xor(xor(a, x), a) -> x
1070 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1071 if (prev.getRhs() == getRhs())
1072 return prev.getLhs();
1073 if (prev.getLhs() == getRhs())
1074 return prev.getRhs();
1075 }
1076 /// xor(a, xor(x, a)) -> x
1077 /// xor(a, xor(a, x)) -> x
1078 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1079 if (prev.getRhs() == getLhs())
1080 return prev.getLhs();
1081 if (prev.getLhs() == getLhs())
1082 return prev.getRhs();
1083 }
1084
1086 adaptor.getOperands(),
1087 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1088}
1089
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register xor canonicalizations (patterns defined elsewhere in the
  // dialect, presumably generated from tablegen — names match that style).
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
}
1094
1095//===----------------------------------------------------------------------===//
1096// NegFOp
1097//===----------------------------------------------------------------------===//
1098
1099OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1100 /// negf(negf(x)) -> x
1101 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1102 return op.getOperand();
1103 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1104 [](const APFloat &a) { return -a; });
1105}
1106
1107//===----------------------------------------------------------------------===//
1108// FlushDenormalsOp
1109//===----------------------------------------------------------------------===//
1110
1111OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
1112 // TODO: Fold flush_denormals if the floating-point type does not support
1113 // denormals. There is currently no API to query this information from
1114 // APFloat.
1115
1116 // flush_denormals(flush_denormals(x)) -> flush_denormals(x)
1117 if (auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
1118 return op.getResult();
1119
1120 // Constant-fold flush_denormals if the operand is a constant.
1122 adaptor.getOperands(), [](const APFloat &a) {
1123 if (a.isDenormal())
1124 return APFloat::getZero(a.getSemantics(), a.isNegative());
1125 return a;
1126 });
1127}
1128
1129//===----------------------------------------------------------------------===//
1130// AddFOp
1131//===----------------------------------------------------------------------===//
1132
1133OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1134 // addf(x, -0) -> x
1135 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1136 return getLhs();
1137
1138 auto rm = getRoundingmode();
1140 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1141 APFloat result(a);
1142 result.add(b, convertArithRoundingModeToLLVMIR(rm));
1143 return result;
1144 });
1145}
1146
1147//===----------------------------------------------------------------------===//
1148// SubFOp
1149//===----------------------------------------------------------------------===//
1150
1151OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1152 // subf(x, +0) -> x
1153 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1154 return getLhs();
1155
1156 auto rm = getRoundingmode();
1158 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1159 APFloat result(a);
1160 result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
1161 return result;
1162 });
1163}
1164
1165//===----------------------------------------------------------------------===//
1166// MaximumFOp
1167//===----------------------------------------------------------------------===//
1168
1169OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1170 // maximumf(x,x) -> x
1171 if (getLhs() == getRhs())
1172 return getRhs();
1173
1174 // maximumf(x, -inf) -> x
1175 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1176 return getLhs();
1177
1179 adaptor.getOperands(),
1180 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1181}
1182
1183//===----------------------------------------------------------------------===//
1184// MaxNumFOp
1185//===----------------------------------------------------------------------===//
1186
1187OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1188 // maxnumf(x,x) -> x
1189 if (getLhs() == getRhs())
1190 return getRhs();
1191
1192 // maxnumf(x, NaN) -> x
1193 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1194 return getLhs();
1195
1196 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1197}
1198
1199//===----------------------------------------------------------------------===//
1200// MaxSIOp
1201//===----------------------------------------------------------------------===//
1202
1203OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1204 // maxsi(x,x) -> x
1205 if (getLhs() == getRhs())
1206 return getRhs();
1207
1208 if (APInt intValue;
1209 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1210 // maxsi(x,MAX_INT) -> MAX_INT
1211 if (intValue.isMaxSignedValue())
1212 return getRhs();
1213 // maxsi(x, MIN_INT) -> x
1214 if (intValue.isMinSignedValue())
1215 return getLhs();
1216 }
1217
1218 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1219 [](const APInt &a, const APInt &b) {
1220 return llvm::APIntOps::smax(a, b);
1221 });
1222}
1223
1224//===----------------------------------------------------------------------===//
1225// MaxUIOp
1226//===----------------------------------------------------------------------===//
1227
1228OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1229 // maxui(x,x) -> x
1230 if (getLhs() == getRhs())
1231 return getRhs();
1232
1233 if (APInt intValue;
1234 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1235 // maxui(x,MAX_INT) -> MAX_INT
1236 if (intValue.isMaxValue())
1237 return getRhs();
1238 // maxui(x, MIN_INT) -> x
1239 if (intValue.isMinValue())
1240 return getLhs();
1241 }
1242
1243 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1244 [](const APInt &a, const APInt &b) {
1245 return llvm::APIntOps::umax(a, b);
1246 });
1247}
1248
1249//===----------------------------------------------------------------------===//
1250// MinimumFOp
1251//===----------------------------------------------------------------------===//
1252
1253OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1254 // minimumf(x,x) -> x
1255 if (getLhs() == getRhs())
1256 return getRhs();
1257
1258 // minimumf(x, +inf) -> x
1259 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1260 return getLhs();
1261
1263 adaptor.getOperands(),
1264 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1265}
1266
1267//===----------------------------------------------------------------------===//
1268// MinNumFOp
1269//===----------------------------------------------------------------------===//
1270
1271OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1272 // minnumf(x,x) -> x
1273 if (getLhs() == getRhs())
1274 return getRhs();
1275
1276 // minnumf(x, NaN) -> x
1277 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1278 return getLhs();
1279
1281 adaptor.getOperands(),
1282 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1283}
1284
1285//===----------------------------------------------------------------------===//
1286// MinSIOp
1287//===----------------------------------------------------------------------===//
1288
1289OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1290 // minsi(x,x) -> x
1291 if (getLhs() == getRhs())
1292 return getRhs();
1293
1294 if (APInt intValue;
1295 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1296 // minsi(x,MIN_INT) -> MIN_INT
1297 if (intValue.isMinSignedValue())
1298 return getRhs();
1299 // minsi(x, MAX_INT) -> x
1300 if (intValue.isMaxSignedValue())
1301 return getLhs();
1302 }
1303
1304 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1305 [](const APInt &a, const APInt &b) {
1306 return llvm::APIntOps::smin(a, b);
1307 });
1308}
1309
1310//===----------------------------------------------------------------------===//
1311// MinUIOp
1312//===----------------------------------------------------------------------===//
1313
1314OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1315 // minui(x,x) -> x
1316 if (getLhs() == getRhs())
1317 return getRhs();
1318
1319 if (APInt intValue;
1320 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1321 // minui(x,MIN_INT) -> MIN_INT
1322 if (intValue.isMinValue())
1323 return getRhs();
1324 // minui(x, MAX_INT) -> x
1325 if (intValue.isMaxValue())
1326 return getLhs();
1327 }
1328
1329 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1330 [](const APInt &a, const APInt &b) {
1331 return llvm::APIntOps::umin(a, b);
1332 });
1333}
1334
1335//===----------------------------------------------------------------------===//
1336// MulFOp
1337//===----------------------------------------------------------------------===//
1338
1339OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1340 // mulf(x, 1) -> x
1341 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1342 return getLhs();
1343
1344 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1345 arith::FastMathFlags::nsz)) {
1346 // mulf(x, 0) -> 0
1347 if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1348 return getRhs();
1349 }
1350
1351 auto rm = getRoundingmode();
1353 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1354 APFloat result(a);
1355 result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
1356 return result;
1357 });
1358}
1359
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register the mulf-of-negf canonicalization (pattern defined elsewhere).
  patterns.add<MulFOfNegF>(context);
}
1364
1365//===----------------------------------------------------------------------===//
1366// DivFOp
1367//===----------------------------------------------------------------------===//
1368
1369OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1370 // divf(x, 1) -> x
1371 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1372 return getLhs();
1373
1374 auto rm = getRoundingmode();
1376 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1377 APFloat result(a);
1378 result.divide(b, convertArithRoundingModeToLLVMIR(rm));
1379 return result;
1380 });
1381}
1382
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register the divf-of-negf canonicalization (pattern defined elsewhere).
  patterns.add<DivFOfNegF>(context);
}
1387
1388//===----------------------------------------------------------------------===//
1389// RemFOp
1390//===----------------------------------------------------------------------===//
1391
1392OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1393 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1394 [](const APFloat &a, const APFloat &b) {
1395 APFloat result(a);
1396 // APFloat::mod() offers the remainder
1397 // behavior we want, i.e. the result has
1398 // the sign of LHS operand.
1399 (void)result.mod(b);
1400 return result;
1401 });
1402}
1403
1404//===----------------------------------------------------------------------===//
1405// Utility functions for verifying cast ops
1406//===----------------------------------------------------------------------===//
1407
/// Compile-time tag: a (null) pointer-to-tuple whose only purpose is to carry
/// a pack of types as a function argument without constructing any values.
template <typename... Types>
using type_list = std::tuple<Types...> *;
1410
1411/// Returns a non-null type only if the provided type is one of the allowed
1412/// types or one of the allowed shaped types of the allowed types. Returns the
1413/// element type if a valid shaped type is provided.
1414template <typename... ShapedTypes, typename... ElementTypes>
1417 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1418 return {};
1419
1420 auto underlyingType = getElementTypeOrSelf(type);
1421 if (!llvm::isa<ElementTypes...>(underlyingType))
1422 return {};
1423
1424 return underlyingType;
1425}
1426
1427/// Get allowed underlying types for vectors and tensors.
1428template <typename... ElementTypes>
1433
1434/// Get allowed underlying types for vectors, tensors, and memrefs.
1435template <typename... ElementTypes>
1441
1442/// Return false if both types are ranked tensor with mismatching encoding.
1443static bool hasSameEncoding(Type typeA, Type typeB) {
1444 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1445 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1446 if (!rankedTensorA || !rankedTensorB)
1447 return true;
1448 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1449}
1450
1452 if (inputs.size() != 1 || outputs.size() != 1)
1453 return false;
1454 if (!hasSameEncoding(inputs.front(), outputs.front()))
1455 return false;
1456 return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1457}
1458
1459//===----------------------------------------------------------------------===//
1460// Verifiers for integer and floating point extension/truncation ops
1461//===----------------------------------------------------------------------===//
1462
1463// Extend ops can only extend to a wider type.
1464template <typename ValType, typename Op>
1465static LogicalResult verifyExtOp(Op op) {
1466 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1467 Type dstType = getElementTypeOrSelf(op.getType());
1468
1469 if (llvm::cast<ValType>(srcType).getWidth() >=
1470 llvm::cast<ValType>(dstType).getWidth())
1471 return op.emitError("result type ")
1472 << dstType << " must be wider than operand type " << srcType;
1473
1474 return success();
1475}
1476
1477// Truncate ops can only truncate to a shorter type.
1478template <typename ValType, typename Op>
1479static LogicalResult verifyTruncateOp(Op op) {
1480 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1481 Type dstType = getElementTypeOrSelf(op.getType());
1482
1483 if (llvm::cast<ValType>(srcType).getWidth() <=
1484 llvm::cast<ValType>(dstType).getWidth())
1485 return op.emitError("result type ")
1486 << dstType << " must be shorter than operand type " << srcType;
1487
1488 return success();
1489}
1490
1491/// Validate a cast that changes the width of a type.
1492template <template <typename> class WidthComparator, typename... ElementTypes>
1493static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1494 if (!areValidCastInputsAndOutputs(inputs, outputs))
1495 return false;
1496
1497 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1498 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1499 if (!srcType || !dstType)
1500 return false;
1501
1502 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1503 srcType.getIntOrFloatBitWidth());
1504}
1505
1506/// Attempts to convert `sourceValue` to an APFloat value with
1507/// `targetSemantics` and `roundingMode`, without any information loss.
1508static FailureOr<APFloat>
1509convertFloatValue(APFloat sourceValue,
1510 const llvm::fltSemantics &targetSemantics,
1511 llvm::RoundingMode roundingMode = kDefaultRoundingMode) {
1512 // Reject special values that are not representable in the target type before
1513 // calling APFloat::convert, which would llvm_unreachable on them.
1514 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1515 if (sourceValue.isInfinity() &&
1516 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1517 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1518 return failure();
1519 if (sourceValue.isNaN() &&
1520 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1521 return failure();
1522
1523 bool losesInfo = false;
1524 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1525 if (losesInfo || status != APFloat::opOK)
1526 return failure();
1527
1528 return sourceValue;
1529}
1530
1531//===----------------------------------------------------------------------===//
1532// ExtUIOp
1533//===----------------------------------------------------------------------===//
1534
1535OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1536 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1537 getInMutable().assign(lhs.getIn());
1538 return getResult();
1539 }
1540
1541 Type resType = getElementTypeOrSelf(getType());
1542 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1544 adaptor.getOperands(), getType(),
1545 [bitWidth](const APInt &a, bool &castStatus) {
1546 return a.zext(bitWidth);
1547 });
1548}
1549
1550bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1552}
1553
LogicalResult arith::ExtUIOp::verify() {
  // Shared checker: result element type must be wider than the operand's.
  return verifyExtOp<IntegerType>(*this);
}
1557
1558//===----------------------------------------------------------------------===//
1559// ExtSIOp
1560//===----------------------------------------------------------------------===//
1561
1562OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1563 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1564 getInMutable().assign(lhs.getIn());
1565 return getResult();
1566 }
1567
1568 Type resType = getElementTypeOrSelf(getType());
1569 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1571 adaptor.getOperands(), getType(),
1572 [bitWidth](const APInt &a, bool &castStatus) {
1573 return a.sext(bitWidth);
1574 });
1575}
1576
1577bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1579}
1580
void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  // Register the extsi-of-extui canonicalization (pattern defined elsewhere).
  patterns.add<ExtSIOfExtUI>(context);
}
1585
LogicalResult arith::ExtSIOp::verify() {
  // Shared checker: result element type must be wider than the operand's.
  return verifyExtOp<IntegerType>(*this);
}
1589
1590//===----------------------------------------------------------------------===//
1591// ExtFOp
1592//===----------------------------------------------------------------------===//
1593
1594/// Fold extension of float constants when there is no information loss due the
1595/// difference in fp semantics.
1596OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1597 if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1598 if (truncFOp.getOperand().getType() == getType()) {
1599 arith::FastMathFlags truncFMF =
1600 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1601 bool isTruncContract =
1602 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1603 arith::FastMathFlags extFMF =
1604 getFastmath().value_or(arith::FastMathFlags::none);
1605 bool isExtContract =
1606 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1607 if (isTruncContract && isExtContract) {
1608 return truncFOp.getOperand();
1609 }
1610 }
1611 }
1612
1613 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1614 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1616 adaptor.getOperands(), getType(),
1617 [&targetSemantics](const APFloat &a, bool &castStatus) {
1618 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1619 if (failed(result)) {
1620 castStatus = false;
1621 return a;
1622 }
1623 return *result;
1624 });
1625}
1626
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Valid when the float width strictly grows.
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}
1630
// Shared checker: result element type must be wider than the operand's.
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1632
1633//===----------------------------------------------------------------------===//
1634// ScalingExtFOp
1635//===----------------------------------------------------------------------===//
1636
bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  // Only the first input (the value being extended) participates in the
  // width check — presumably the remaining input is the scale operand, which
  // is not constrained here; verify against the op definition.
  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
}
1641
LogicalResult arith::ScalingExtFOp::verify() {
  // Shared checker: result element type must be wider than the operand's.
  return verifyExtOp<FloatType>(*this);
}
1645
1646//===----------------------------------------------------------------------===//
1647// TruncIOp
1648//===----------------------------------------------------------------------===//
1649
1650OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1651 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1652 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1653 Value src = getOperand().getDefiningOp()->getOperand(0);
1654 Type srcType = getElementTypeOrSelf(src.getType());
1655 Type dstType = getElementTypeOrSelf(getType());
1656 // trunci(zexti(a)) -> trunci(a)
1657 // trunci(sexti(a)) -> trunci(a)
1658 if (llvm::cast<IntegerType>(srcType).getWidth() >
1659 llvm::cast<IntegerType>(dstType).getWidth()) {
1660 setOperand(src);
1661 return getResult();
1662 }
1663
1664 // trunci(zexti(a)) -> a
1665 // trunci(sexti(a)) -> a
1666 if (srcType == dstType)
1667 return src;
1668 }
1669
1670 // trunci(trunci(a)) -> trunci(a))
1671 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1672 setOperand(getOperand().getDefiningOp()->getOperand(0));
1673 return getResult();
1674 }
1675
1676 Type resType = getElementTypeOrSelf(getType());
1677 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1679 adaptor.getOperands(), getType(),
1680 [bitWidth](const APInt &a, bool &castStatus) {
1681 return a.trunc(bitWidth);
1682 });
1683}
1684
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Valid when the integer width strictly shrinks.
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}
1688
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Register trunci canonicalizations (patterns defined elsewhere).
  patterns
      .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
          context);
}
1695
LogicalResult arith::TruncIOp::verify() {
  // Shared checker: result element type must be narrower than the operand's.
  return verifyTruncateOp<IntegerType>(*this);
}
1699
1700//===----------------------------------------------------------------------===//
1701// TruncFOp
1702//===----------------------------------------------------------------------===//
1703
1704/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1705/// can be represented without precision loss.
1706OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1707 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1708 if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1709 Value src = extOp.getIn();
1710 auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1711 auto intermediateType =
1712 cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1713 // Check if the srcType is representable in the intermediateType.
1714 if (llvm::APFloatBase::isRepresentableBy(
1715 srcType.getFloatSemantics(),
1716 intermediateType.getFloatSemantics())) {
1717 // truncf(extf(a)) -> truncf(a)
1718 if (srcType.getWidth() > resElemType.getWidth()) {
1719 setOperand(src);
1720 return getResult();
1721 }
1722
1723 // truncf(extf(a)) -> a
1724 if (srcType == resElemType)
1725 return src;
1726 }
1727 }
1728
1729 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1731 adaptor.getOperands(), getType(),
1732 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1733 llvm::RoundingMode llvmRoundingMode =
1734 convertArithRoundingModeToLLVMIR(getRoundingmode());
1735 FailureOr<APFloat> result =
1736 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1737 if (failed(result)) {
1738 castStatus = false;
1739 return a;
1740 }
1741 return *result;
1742 });
1743}
1744
void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Register truncf canonicalizations (patterns defined elsewhere).
  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
}
1749
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Valid when the float width strictly shrinks.
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}
1753
LogicalResult arith::TruncFOp::verify() {
  // Shared checker: result element type must be narrower than the operand's.
  return verifyTruncateOp<FloatType>(*this);
}
1757
1758//===----------------------------------------------------------------------===//
1759// ConvertFOp
1760//===----------------------------------------------------------------------===//
1761
1762OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1763 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1764 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1766 adaptor.getOperands(), getType(),
1767 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1768 llvm::RoundingMode llvmRoundingMode =
1769 convertArithRoundingModeToLLVMIR(getRoundingmode());
1770 FailureOr<APFloat> result =
1771 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1772 if (failed(result)) {
1773 castStatus = false;
1774 return a;
1775 }
1776 return *result;
1777 });
1778}
1779
1780bool arith::ConvertFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1781 if (!areValidCastInputsAndOutputs(inputs, outputs))
1782 return false;
1783 auto srcType = getTypeIfLike<FloatType>(inputs.front());
1784 auto dstType = getTypeIfLike<FloatType>(outputs.front());
1785 if (!srcType || !dstType)
1786 return false;
1787 return srcType != dstType &&
1788 srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1789}
1790
LogicalResult arith::ConvertFOp::verify() {
  // A convertf must change the float type while preserving the bitwidth:
  // reject identity conversions and any width change.
  auto srcType = cast<FloatType>(getElementTypeOrSelf(getIn().getType()));
  auto dstType = cast<FloatType>(getElementTypeOrSelf(getType()));
  if (srcType == dstType)
    return emitError("result element type ")
           << dstType << " must be different from operand element type "
           << srcType;
  if (srcType.getWidth() != dstType.getWidth())
    return emitError("result element type ")
           << dstType << " must have the same bitwidth as operand element type "
           << srcType;
  return success();
}
1804
1805//===----------------------------------------------------------------------===//
1806// ScalingTruncFOp
1807//===----------------------------------------------------------------------===//
1808
bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
                                               TypeRange outputs) {
  // Only the first input (the value being truncated) participates in the
  // width check — presumably the remaining input is the scale operand, which
  // is not constrained here; verify against the op definition.
  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
}
1813
LogicalResult arith::ScalingTruncFOp::verify() {
  // Shared checker: result element type must be narrower than the operand's.
  return verifyTruncateOp<FloatType>(*this);
}
1817
1818//===----------------------------------------------------------------------===//
1819// AndIOp
1820//===----------------------------------------------------------------------===//
1821
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register andi canonicalizations (patterns defined elsewhere).
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}
1826
1827//===----------------------------------------------------------------------===//
1828// OrIOp
1829//===----------------------------------------------------------------------===//
1830
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // Register ori canonicalizations (patterns defined elsewhere).
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}
1835
1836//===----------------------------------------------------------------------===//
1837// Verifiers for casts between integers and floats.
1838//===----------------------------------------------------------------------===//
1839
// A cast between int and float is compatible when the input is a (shaped)
// `From` type and the output a (shaped) `To` type, with valid cast structure.
template <typename From, typename To>
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto srcType = getTypeIfLike<From>(inputs.front());
  auto dstType = getTypeIfLike<To>(outputs.back());

  return srcType && dstType;
}
1850
1851//===----------------------------------------------------------------------===//
1852// UIToFPOp
1853//===----------------------------------------------------------------------===//
1854
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Integer-like input, float-like output.
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1858
1859OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1860 Type resEleType = getElementTypeOrSelf(getType());
1862 adaptor.getOperands(), getType(),
1863 [&resEleType](const APInt &a, bool &castStatus) {
1864 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1865 APFloat apf(floatTy.getFloatSemantics(),
1866 APInt::getZero(floatTy.getWidth()));
1867 apf.convertFromAPInt(a, /*IsSigned=*/false,
1868 APFloat::rmNearestTiesToEven);
1869 return apf;
1870 });
1871}
1872
void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Register the uitofp-of-extui canonicalization (pattern defined elsewhere).
  patterns.add<UIToFPOfExtUI>(context);
}
1877
1878//===----------------------------------------------------------------------===//
1879// SIToFPOp
1880//===----------------------------------------------------------------------===//
1881
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Integer-like input, float-like output.
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1885
1886OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1887 Type resEleType = getElementTypeOrSelf(getType());
1889 adaptor.getOperands(), getType(),
1890 [&resEleType](const APInt &a, bool &castStatus) {
1891 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1892 APFloat apf(floatTy.getFloatSemantics(),
1893 APInt::getZero(floatTy.getWidth()));
1894 apf.convertFromAPInt(a, /*IsSigned=*/true,
1895 APFloat::rmNearestTiesToEven);
1896 return apf;
1897 });
1898}
1899
void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Register sitofp canonicalizations (patterns defined elsewhere).
  patterns.add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
}
1904
1905//===----------------------------------------------------------------------===//
1906// FPToUIOp
1907//===----------------------------------------------------------------------===//
1908
/// An fptoui cast is compatible when the source is a float type and the
/// destination an integer type (checked element-wise for shaped types).
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1912
1913OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1914 Type resType = getElementTypeOrSelf(getType());
1915 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1917 adaptor.getOperands(), getType(),
1918 [&bitWidth](const APFloat &a, bool &castStatus) {
1919 bool ignored;
1920 APSInt api(bitWidth, /*isUnsigned=*/true);
1921 castStatus = APFloat::opInvalidOp !=
1922 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1923 return api;
1924 });
1925}
1926
1927//===----------------------------------------------------------------------===//
1928// FPToSIOp
1929//===----------------------------------------------------------------------===//
1930
/// An fptosi cast is compatible when the source is a float type and the
/// destination an integer type (checked element-wise for shaped types).
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1934
1935OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1936 Type resType = getElementTypeOrSelf(getType());
1937 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1939 adaptor.getOperands(), getType(),
1940 [&bitWidth](const APFloat &a, bool &castStatus) {
1941 bool ignored;
1942 APSInt api(bitWidth, /*isUnsigned=*/false);
1943 castStatus = APFloat::opInvalidOp !=
1944 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1945 return api;
1946 });
1947}
1948
1949//===----------------------------------------------------------------------===//
1950// IndexCastOp
1951//===----------------------------------------------------------------------===//
1952
1953/// Return the bit-width of \p t for the purpose of index_cast width checks.
1954/// For vector types use the element type; index maps to its internal storage
1955/// width (64 on all current targets).
1956static unsigned getIndexCastWidth(Type t) {
1957 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(t)))
1958 return intTy.getWidth();
1959 return IndexType::kInternalStorageBitWidth;
1960}
1961
1962static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1963 if (!areValidCastInputsAndOutputs(inputs, outputs))
1964 return false;
1965
1966 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1967 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1968 if (!srcType || !dstType)
1969 return false;
1970
1971 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1972 (srcType.isSignlessInteger() && dstType.isIndex());
1973}
1974
/// index_cast is valid between `index` and signless integers (either
/// direction); see areIndexCastCompatible.
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}
1979
/// Fold index_cast: constants fold via sign-extend-or-truncate, and a
/// round-trip index_cast(index_cast(x)) folds to x when the intermediate
/// type did not truncate.
OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
          adaptor.getOperands(), getType(),
          [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
            return a.sextOrTrunc(resultBitwidth);
          }))
    return foldResult;

  // index_cast(index_cast(x : A) : B) : A -> x, but only when B is at least
  // as wide as A. If B is narrower, the inner cast truncates and the outer
  // cast sign-extends, so the round-trip is lossy.
  if (auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
    Value x = inner.getOperand();
    // Only a true round-trip (result type matches the innermost operand).
    if (x.getType() == getType()) {
      if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
        return x;
    }
  }
  return {};
}
2005
/// Register canonicalization: index_cast(extsi(x)) -> index_cast(x).
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfExtSI>(context);
}
2010
2011//===----------------------------------------------------------------------===//
2012// IndexCastUIOp
2013//===----------------------------------------------------------------------===//
2014
/// index_castui is valid between `index` and signless integers (either
/// direction); see areIndexCastCompatible.
bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}
2019
/// Fold index_castui: constants fold via zero-extend-or-truncate, and a
/// round-trip index_castui(index_castui(x)) folds to x when the intermediate
/// type did not truncate.
OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  unsigned resultBitwidth = 64; // Default for index integer attributes.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
    resultBitwidth = intTy.getWidth();

  if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
          adaptor.getOperands(), getType(),
          [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
            return a.zextOrTrunc(resultBitwidth);
          }))
    return foldResult;

  // index_castui(index_castui(x : A) : B) : A -> x, but only when B is at
  // least as wide as A. If B is narrower, the inner cast truncates and the
  // outer cast zero-extends, so the round-trip is lossy.
  if (auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
    Value x = inner.getOperand();
    // Only a true round-trip (result type matches the innermost operand).
    if (x.getType() == getType()) {
      if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
        return x;
    }
  }
  return {};
}
2045
/// Register canonicalization: index_castui(extui(x)) -> index_castui(x).
void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastUIOfExtUI>(context);
}
2050
2051//===----------------------------------------------------------------------===//
2052// BitcastOp
2053//===----------------------------------------------------------------------===//
2054
2055bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2056 if (!areValidCastInputsAndOutputs(inputs, outputs))
2057 return false;
2058
2059 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
2060 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
2061 if (!srcType || !dstType)
2062 return false;
2063
2064 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
2065}
2066
/// Fold bitcast over constants: dense elements are bitcast element-wise,
/// poison stays poison, and scalar int/float constants are reinterpreted via
/// their raw APInt bits.
OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  auto resType = getType();
  auto operand = adaptor.getIn();
  if (!operand)
    return {};

  /// Bitcast dense elements.
  if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
  /// Other shaped types unhandled.
  if (llvm::isa<ShapedType>(resType))
    return {};

  /// Bitcast poison.
  if (matchPattern(operand, ub::m_Poison()))
    return ub::PoisonAttr::get(getContext());

  /// Bitcast integer or float to integer or float.
  // Extract the raw bits; a float constant contributes its IEEE bit pattern.
  APInt bits = llvm::isa<FloatAttr>(operand)
                   ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                   : llvm::cast<IntegerAttr>(operand).getValue();
  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
         "trying to fold on broken IR: operands have incompatible types");

  if (auto resFloatType = dyn_cast<FloatType>(resType))
    return FloatAttr::get(resType,
                          APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(resType, bits);
}
2096
/// Register canonicalization: bitcast(bitcast(x)) -> bitcast(x).
void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
2101
2102//===----------------------------------------------------------------------===//
2103// CmpIOp
2104//===----------------------------------------------------------------------===//
2105
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates. Dispatches to the matching signed/unsigned APInt
/// comparison; the switch covers all ten predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
2134
2135/// Returns true if the predicate is true for two equal operands.
2136static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
2137 switch (predicate) {
2138 case arith::CmpIPredicate::eq:
2139 case arith::CmpIPredicate::sle:
2140 case arith::CmpIPredicate::sge:
2141 case arith::CmpIPredicate::ule:
2142 case arith::CmpIPredicate::uge:
2143 return true;
2144 case arith::CmpIPredicate::ne:
2145 case arith::CmpIPredicate::slt:
2146 case arith::CmpIPredicate::sgt:
2147 case arith::CmpIPredicate::ult:
2148 case arith::CmpIPredicate::ugt:
2149 return false;
2150 }
2151 llvm_unreachable("unknown cmpi predicate kind");
2152}
2153
2154static std::optional<int64_t> getIntegerWidth(Type t) {
2155 if (auto intType = dyn_cast<IntegerType>(t)) {
2156 return intType.getWidth();
2157 }
2158 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
2159 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2160 }
2161 return std::nullopt;
2162}
2163
2164OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2165 // cmpi(pred, x, x)
2166 if (getLhs() == getRhs()) {
2167 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2168 return getBoolAttribute(getType(), val);
2169 }
2170
2171 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2172 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2173 // extsi(%x : i1 -> iN) != 0 -> %x
2174 std::optional<int64_t> integerWidth =
2175 getIntegerWidth(extOp.getOperand().getType());
2176 if (integerWidth && integerWidth.value() == 1 &&
2177 getPredicate() == arith::CmpIPredicate::ne)
2178 return extOp.getOperand();
2179 }
2180 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2181 // extui(%x : i1 -> iN) != 0 -> %x
2182 std::optional<int64_t> integerWidth =
2183 getIntegerWidth(extOp.getOperand().getType());
2184 if (integerWidth && integerWidth.value() == 1 &&
2185 getPredicate() == arith::CmpIPredicate::ne)
2186 return extOp.getOperand();
2187 }
2188
2189 // arith.cmpi ne, %val, %zero : i1 -> %val
2190 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2191 getPredicate() == arith::CmpIPredicate::ne)
2192 return getLhs();
2193 }
2194
2195 if (matchPattern(adaptor.getRhs(), m_One())) {
2196 // arith.cmpi eq, %val, %one : i1 -> %val
2197 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2198 getPredicate() == arith::CmpIPredicate::eq)
2199 return getLhs();
2200 }
2201
2202 // Move constant to the right side.
2203 if (adaptor.getLhs() && !adaptor.getRhs()) {
2204 // Do not use invertPredicate, as it will change eq to ne and vice versa.
2205 using Pred = CmpIPredicate;
2206 const std::pair<Pred, Pred> invPreds[] = {
2207 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2208 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2209 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2210 {Pred::ne, Pred::ne},
2211 };
2212 Pred origPred = getPredicate();
2213 for (auto pred : invPreds) {
2214 if (origPred == pred.first) {
2215 setPredicate(pred.second);
2216 Value lhs = getLhs();
2217 Value rhs = getRhs();
2218 getLhsMutable().assign(rhs);
2219 getRhsMutable().assign(lhs);
2220 return getResult();
2221 }
2222 }
2223 llvm_unreachable("unknown cmpi predicate kind");
2224 }
2225
2226 // We are moving constants to the right side; So if lhs is constant rhs is
2227 // guaranteed to be a constant.
2228 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2230 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2231 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2232 return APInt(1,
2233 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2234 });
2235 }
2236
2237 return {};
2238}
2239
/// Register canonicalizations that push cmpi through extsi/extui operands.
void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpIExtSI, CmpIExtUI>(context);
}
2244
2245//===----------------------------------------------------------------------===//
2246// CmpFOp
2247//===----------------------------------------------------------------------===//
2248
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates. Ordered (O*) predicates are false when either
/// operand is NaN (cmpUnordered); unordered (U*) predicates are then true.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown cmpf predicate kind");
}
2296
2297OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2298 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2299 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2300
2301 // If one operand is NaN, making them both NaN does not change the result.
2302 if (lhs && lhs.getValue().isNaN())
2303 rhs = lhs;
2304 if (rhs && rhs.getValue().isNaN())
2305 lhs = rhs;
2306
2307 if (!lhs || !rhs)
2308 return {};
2309
2310 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2311 return BoolAttr::get(getContext(), val);
2312}
2313
/// Rewrite `cmpf(sitofp(x), C)` / `cmpf(uitofp(x), C)` against a float
/// constant C into an integer `cmpi` on x, when the int-to-float conversion
/// and the constant permit it without changing the result. Handles range
/// overflow (constant beyond the integer type's min/max) and fractional
/// constants by folding to true/false or adjusting the predicate.
class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
public:
  using Base::Base;

  /// Map a float predicate onto the integer predicate with the same meaning;
  /// `isUnsigned` selects the unsigned vs signed ordered forms. Must only be
  /// called for the eq/ne/ordering predicates (ORD/UNO and the Always*
  /// predicates are handled by the caller before reaching here).
  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                                 bool isUnsigned) {
    using namespace arith;
    switch (pred) {
    case CmpFPredicate::UEQ:
    case CmpFPredicate::OEQ:
      return CmpIPredicate::eq;
    case CmpFPredicate::UGT:
    case CmpFPredicate::OGT:
      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
    case CmpFPredicate::UGE:
    case CmpFPredicate::OGE:
      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
    case CmpFPredicate::ULT:
    case CmpFPredicate::OLT:
      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
    case CmpFPredicate::ULE:
    case CmpFPredicate::OLE:
      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
    case CmpFPredicate::UNE:
    case CmpFPredicate::ONE:
      return CmpIPredicate::ne;
    default:
      llvm_unreachable("Unexpected predicate!");
    }
  }

  LogicalResult matchAndRewrite(CmpFOp op,
                                PatternRewriter &rewriter) const override {
    FloatAttr flt;
    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
      return failure();

    const APFloat &rhs = flt.getValue();

    // Don't attempt to fold a nan.
    if (rhs.isNaN())
      return failure();

    // Get the width of the mantissa. We don't want to hack on conversions that
    // might lose information from the integer, e.g. "i64 -> float"
    FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)
      return failure();

    bool isUnsigned;
    Value intVal;

    // Identify which int-to-float conversion feeds the lhs.
    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      isUnsigned = false;
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      isUnsigned = true;
      intVal = ui.getIn();
    } else {
      return failure();
    }

    // Check to see that the input is converted from an integer type that is
    // small enough that preserves all bits.
    auto intTy = llvm::cast<IntegerType>(intVal.getType());
    auto intWidth = intTy.getWidth();

    // Number of bits representing values, as opposed to the sign
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    // Following test does NOT adjust intWidth downwards for signed inputs,
    // because the most negative value still requires all the mantissa bits
    // to distinguish it from one less than that value.
    if ((int)intWidth > mantissaWidth) {
      // Conversion would lose accuracy. Check if loss can impact comparison.
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
        if (maxExponent < (int)valueBits) {
          // Conversion could create infinity.
          return failure();
        }
      } else {
        // Note that if rhs is zero or NaN, then Exp is negative
        // and first condition is trivially false.
        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
          // Conversion could affect comparison.
          return failure();
        }
      }
    }

    // Convert to equivalent cmpi predicate
    CmpIPredicate pred;
    switch (op.getPredicate()) {
    case CmpFPredicate::ORD:
      // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                 /*width=*/1);
      return success();
    case CmpFPredicate::UNO:
      // Int to fp conversion doesn't create a nan (uno checks either is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                 /*width=*/1);
      return success();
    default:
      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
      break;
    }

    if (!isUnsigned) {
      // If the rhs value is > SignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat signedMax(rhs.getSemantics());
      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMax < rhs) { // smax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
            pred == CmpIPredicate::sle)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // If the rhs value is > UnsignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat unsignedMax(rhs.getSemantics());
      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMax < rhs) { // umax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
            pred == CmpIPredicate::ule)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    }

    if (!isUnsigned) {
      // See if the rhs value is < SignedMin.
      APFloat signedMin(rhs.getSemantics());
      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMin > rhs) { // smin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
            pred == CmpIPredicate::sge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // See if the rhs value is < UnsignedMin.
      APFloat unsignedMin(rhs.getSemantics());
      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMin > rhs) { // umin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
            pred == CmpIPredicate::uge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    }

    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
    // [0, UMAX], but it may still be fractional. See if it is fractional by
    // casting the FP value to the integer value and back, checking for
    // equality. Don't do this for zero, because -0.0 is not fractional.
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
      // Undefined behavior invoked - the destination type can't represent
      // the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        // If we had a comparison against a fractional value, we have to adjust
        // the compare predicate and sometimes the value. rhsInt is rounded
        // towards zero at this point.
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4 --> true
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::eq: // (float)int == 4.4 --> false
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::ule:
          // (float)int <= 4.4 --> int <= 4
          // (float)int <= -4.4 --> false
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sle:
          // (float)int <= 4.4 --> int <= 4
          // (float)int <= -4.4 --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4 --> false
          // (float)int < 4.4 --> int <= 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < -4.4 --> int < -4
          // (float)int < 4.4 --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > 4.4 --> int > 4
          // (float)int > -4.4 --> true
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sgt:
          // (float)int > 4.4 --> int > 4
          // (float)int > -4.4 --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4 --> true
          // (float)int >= 4.4 --> int > 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= -4.4 --> int >= -4
          // (float)int >= 4.4 --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        }
      }
    }

    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
                           rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }
};
2598
/// Register canonicalization: cmpf of an int-to-fp conversion against a float
/// constant becomes a cmpi (see CmpFIntToFPConst).
void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpFIntToFPConst>(context);
}
2603
2604//===----------------------------------------------------------------------===//
2605// SelectOp
2606//===----------------------------------------------------------------------===//
2607
// select %arg, %c1, %c0 => extui %arg
// (and the negated form via an i1 xor). Only applies to non-i1 integer
// results, since extui cannot produce i1 or float types.
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Cannot extui i1 to i1, or i1 to f32
    if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
      return failure();

    // select %x, c1, %c0 => extui %arg
    if (matchPattern(op.getTrueValue(), m_One()) &&
        matchPattern(op.getFalseValue(), m_Zero())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
                                                  op.getCondition());
      return success();
    }

    // select %x, c0, %c1 => extui (xor %arg, true)
    if (matchPattern(op.getTrueValue(), m_Zero()) &&
        matchPattern(op.getFalseValue(), m_One())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          op, op.getType(),
          arith::XOrIOp::create(
              rewriter, op.getLoc(), op.getCondition(),
              arith::ConstantIntOp::create(rewriter, op.getLoc(),
                                           op.getCondition().getType(), 1)));
      return success();
    }

    return failure();
  }
};
2641
/// Register select canonicalizations (redundant selects, negated conditions,
/// i1 simplifications, and select-to-extui).
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectI1ToNot, SelectToExtUI>(context);
}
2647
/// Fold select: identical branches, constant/poison conditions or values,
/// i1 identity selects, selects fed by eq/ne comparisons of their own
/// operands, and element-wise folding over dense constant operands.
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(adaptor.getCondition(), m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(adaptor.getCondition(), m_Zero()))
    return falseVal;

  // If either operand is fully poisoned, return the other.
  if (matchPattern(adaptor.getTrueValue(), ub::m_Poison()))
    return falseVal;

  if (matchPattern(adaptor.getFalseValue(), ub::m_Poison()))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isSignlessInteger(1) &&
      matchPattern(adaptor.getTrueValue(), m_One()) &&
      matchPattern(adaptor.getFalseValue(), m_Zero()))
    return condition;

  if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1

      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0

      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }

  // Constant-fold constant operands over non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  if (auto cond =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
    // DenseElementsAttr by construction always has a static shape.
    assert(cond.getType().hasStaticShape() &&
           "DenseElementsAttr must have static shape");
    if (auto lhs =
            dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
      if (auto rhs =
              dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
                                         cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
                                        lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
                                        rhs.value_end<Attribute>());

        // Pick each element from the true or false constant per the mask.
        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(condVals, lhsVals, rhsVals))
          results.push_back(condVal.getValue() ? lhsVal : rhsVal);

        return DenseElementsAttr::get(lhs.getType(), results);
      }
    }
  }

  return nullptr;
}
2726
/// Parse `select cond, a, b : [condType,] resultType`; the condition type is
/// only spelled explicitly for shaped (mask) conditions, otherwise it is i1.
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  if (succeeded(parser.parseOptionalComma())) {
    // The type parsed first was actually the condition (mask) type.
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}
2749
2750void arith::SelectOp::print(OpAsmPrinter &p) {
2751 p << " " << getOperands();
2752 p.printOptionalAttrDict((*this)->getAttrs());
2753 p << " : ";
2754 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2755 p << condType << ", ";
2756 p << getType();
2757}
2758
2759LogicalResult arith::SelectOp::verify() {
2760 Type conditionType = getCondition().getType();
2761 if (conditionType.isSignlessInteger(1))
2762 return success();
2763
2764 // If the result type is a vector or tensor, the type can be a mask with the
2765 // same elements.
2766 Type resultType = getType();
2767 if (!llvm::isa<TensorType, VectorType>(resultType))
2768 return emitOpError() << "expected condition to be a signless i1, but got "
2769 << conditionType;
2770 Type shapedConditionType = getI1SameShape(resultType);
2771 if (conditionType != shapedConditionType) {
2772 return emitOpError() << "expected condition type to have the same shape "
2773 "as the result type, expected "
2774 << shapedConditionType << ", but got "
2775 << conditionType;
2776 }
2777 return success();
2778}
2779//===----------------------------------------------------------------------===//
2780// ShLIOp
2781//===----------------------------------------------------------------------===//
2782
2783OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2784 // shli(x, 0) -> x
2785 if (matchPattern(adaptor.getRhs(), m_Zero()))
2786 return getLhs();
2787 // Don't fold if shifting more or equal than the bit width.
2788 bool bounded = false;
2790 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2791 bounded = b.ult(b.getBitWidth());
2792 return a.shl(b);
2793 });
2794 return bounded ? result : Attribute();
2795}
2796
2797//===----------------------------------------------------------------------===//
2798// ShRUIOp
2799//===----------------------------------------------------------------------===//
2800
2801OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2802 // shrui(x, 0) -> x
2803 if (matchPattern(adaptor.getRhs(), m_Zero()))
2804 return getLhs();
2805 // Don't fold if shifting more or equal than the bit width.
2806 bool bounded = false;
2808 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2809 bounded = b.ult(b.getBitWidth());
2810 return a.lshr(b);
2811 });
2812 return bounded ? result : Attribute();
2813}
2814
2815//===----------------------------------------------------------------------===//
2816// ShRSIOp
2817//===----------------------------------------------------------------------===//
2818
2819OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2820 // shrsi(x, 0) -> x
2821 if (matchPattern(adaptor.getRhs(), m_Zero()))
2822 return getLhs();
2823 // Don't fold if shifting more or equal than the bit width.
2824 bool bounded = false;
2826 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2827 bounded = b.ult(b.getBitWidth());
2828 return a.ashr(b);
2829 });
2830 return bounded ? result : Attribute();
2831}
2832
2833//===----------------------------------------------------------------------===//
2834// Atomic Enum
2835//===----------------------------------------------------------------------===//
2836
2837/// Returns the identity value attribute associated with an AtomicRMWKind op.
2838TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2839 OpBuilder &builder, Location loc,
2840 bool useOnlyFiniteValue) {
2841 switch (kind) {
2842 case AtomicRMWKind::maximumf: {
2843 const llvm::fltSemantics &semantic =
2844 llvm::cast<FloatType>(resultType).getFloatSemantics();
2845 APFloat identity = useOnlyFiniteValue
2846 ? APFloat::getLargest(semantic, /*Negative=*/true)
2847 : APFloat::getInf(semantic, /*Negative=*/true);
2848 return builder.getFloatAttr(resultType, identity);
2849 }
2850 case AtomicRMWKind::maxnumf: {
2851 const llvm::fltSemantics &semantic =
2852 llvm::cast<FloatType>(resultType).getFloatSemantics();
2853 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2854 return builder.getFloatAttr(resultType, identity);
2855 }
2856 case AtomicRMWKind::addf:
2857 case AtomicRMWKind::addi:
2858 case AtomicRMWKind::maxu:
2859 case AtomicRMWKind::ori:
2860 case AtomicRMWKind::xori:
2861 return builder.getZeroAttr(resultType);
2862 case AtomicRMWKind::andi:
2863 return builder.getIntegerAttr(
2864 resultType,
2865 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2866 case AtomicRMWKind::maxs:
2867 return builder.getIntegerAttr(
2868 resultType, APInt::getSignedMinValue(
2869 llvm::cast<IntegerType>(resultType).getWidth()));
2870 case AtomicRMWKind::minimumf: {
2871 const llvm::fltSemantics &semantic =
2872 llvm::cast<FloatType>(resultType).getFloatSemantics();
2873 APFloat identity = useOnlyFiniteValue
2874 ? APFloat::getLargest(semantic, /*Negative=*/false)
2875 : APFloat::getInf(semantic, /*Negative=*/false);
2876
2877 return builder.getFloatAttr(resultType, identity);
2878 }
2879 case AtomicRMWKind::minnumf: {
2880 const llvm::fltSemantics &semantic =
2881 llvm::cast<FloatType>(resultType).getFloatSemantics();
2882 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2883 return builder.getFloatAttr(resultType, identity);
2884 }
2885 case AtomicRMWKind::mins:
2886 return builder.getIntegerAttr(
2887 resultType, APInt::getSignedMaxValue(
2888 llvm::cast<IntegerType>(resultType).getWidth()));
2889 case AtomicRMWKind::minu:
2890 return builder.getIntegerAttr(
2891 resultType,
2892 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2893 case AtomicRMWKind::muli:
2894 return builder.getIntegerAttr(resultType, 1);
2895 case AtomicRMWKind::mulf:
2896 return builder.getFloatAttr(resultType, 1);
2897 // TODO: Add remaining reduction operations.
2898 default:
2899 (void)emitOptionalError(loc, "Reduction operation type not supported");
2900 break;
2901 }
2902 return nullptr;
2903}
2904
2905/// Returns the identity numeric value of the given op.
2906std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2907 std::optional<AtomicRMWKind> maybeKind =
2909 // Floating-point operations.
2910 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2911 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2912 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2913 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2914 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2915 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2916 // Integer operations.
2917 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2918 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2919 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2920 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2921 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2922 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2923 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2924 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2925 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2926 .Default(std::nullopt);
2927 if (!maybeKind) {
2928 return std::nullopt;
2929 }
2930
2931 bool useOnlyFiniteValue = false;
2932 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2933 if (fmfOpInterface) {
2934 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2935 useOnlyFiniteValue =
2936 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2937 }
2938
2939 // Builder only used as helper for attribute creation.
2940 OpBuilder b(op->getContext());
2941 Type resultType = op->getResult(0).getType();
2942
2943 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2944 useOnlyFiniteValue);
2945}
2946
2947/// Returns the identity value associated with an AtomicRMWKind op.
2948Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2949 OpBuilder &builder, Location loc,
2950 bool useOnlyFiniteValue) {
2951 if (auto attr =
2952getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue))
2953 return arith::ConstantOp::create(builder, loc, attr);
2954 return {};
2955}
2956
2957/// Return the value obtained by applying the reduction operation kind
2958/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2960 Location loc, Value lhs, Value rhs) {
2961 switch (op) {
2962 case AtomicRMWKind::addf:
2963 return arith::AddFOp::create(builder, loc, lhs, rhs);
2964 case AtomicRMWKind::addi:
2965 return arith::AddIOp::create(builder, loc, lhs, rhs);
2966 case AtomicRMWKind::mulf:
2967 return arith::MulFOp::create(builder, loc, lhs, rhs);
2968 case AtomicRMWKind::muli:
2969 return arith::MulIOp::create(builder, loc, lhs, rhs);
2970 case AtomicRMWKind::maximumf:
2971 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2972 case AtomicRMWKind::minimumf:
2973 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2974 case AtomicRMWKind::maxnumf:
2975 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2976 case AtomicRMWKind::minnumf:
2977 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2978 case AtomicRMWKind::maxs:
2979 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2980 case AtomicRMWKind::mins:
2981 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2982 case AtomicRMWKind::maxu:
2983 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2984 case AtomicRMWKind::minu:
2985 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2986 case AtomicRMWKind::ori:
2987 return arith::OrIOp::create(builder, loc, lhs, rhs);
2988 case AtomicRMWKind::andi:
2989 return arith::AndIOp::create(builder, loc, lhs, rhs);
2990 case AtomicRMWKind::xori:
2991 return arith::XOrIOp::create(builder, loc, lhs, rhs);
2992 // TODO: Add remaining reduction operations.
2993 default:
2994 (void)emitOptionalError(loc, "Reduction operation type not supported");
2995 break;
2996 }
2997 return nullptr;
2998}
2999
3000//===----------------------------------------------------------------------===//
3001// TableGen'd op method definitions
3002//===----------------------------------------------------------------------===//
3003
3004#define GET_OP_CLASSES
3005#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
3006
3007//===----------------------------------------------------------------------===//
3008// TableGen'd enum attribute definitions
3009//===----------------------------------------------------------------------===//
3010
3011#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:727
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:65
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:72
static constexpr llvm::RoundingMode kDefaultRoundingMode
Default rounding mode according to default LLVM floating-point environment.
Definition ArithOps.cpp:38
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=kDefaultRoundingMode)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:688
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(std::optional< RoundingMode > roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:112
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:787
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:769
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:55
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:155
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:60
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:135
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition ArithOps.cpp:981
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:179
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:46
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:443
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:147
static unsigned getIndexCastWidth(Type t)
Return the bit-width of t for the purpose of index_cast width checks.
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:340
static bool classof(Operation *op)
Definition ArithOps.cpp:357
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:334
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:363
static bool classof(Operation *op)
Definition ArithOps.cpp:384
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:268
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:261
static bool classof(Operation *op)
Definition ArithOps.cpp:328
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:79
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:390
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.