// MLIR 23.0.0git — ArithOps.cpp (documentation snapshot header)
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37/// Default rounding mode according to default LLVM floating-point environment.
/// Used when an op carries no explicit `RoundingMode` attribute.
38static constexpr llvm::RoundingMode kDefaultRoundingMode =
39 llvm::RoundingMode::NearestTiesToEven;
40
41//===----------------------------------------------------------------------===//
42// Pattern helpers
43//===----------------------------------------------------------------------===//
44
45static IntegerAttr
48 function_ref<APInt(const APInt &, const APInt &)> binFn) {
49 const APInt &lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
50 const APInt &rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
51 APInt value = binFn(lhsVal, rhsVal);
52 return IntegerAttr::get(res.getType(), value);
53}
54
55static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
57 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
58}
59
60static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
62 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
63}
64
65static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
67 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
68}
69
70// Merge overflow flags from 2 ops, selecting the most conservative combination.
71static IntegerOverflowFlagsAttr
72mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
73 IntegerOverflowFlagsAttr val2) {
74 return IntegerOverflowFlagsAttr::get(val1.getContext(),
75 val1.getValue() & val2.getValue());
76}
77
78/// Invert an integer comparison predicate.
79arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
80 switch (pred) {
81 case arith::CmpIPredicate::eq:
82 return arith::CmpIPredicate::ne;
83 case arith::CmpIPredicate::ne:
84 return arith::CmpIPredicate::eq;
85 case arith::CmpIPredicate::slt:
86 return arith::CmpIPredicate::sge;
87 case arith::CmpIPredicate::sle:
88 return arith::CmpIPredicate::sgt;
89 case arith::CmpIPredicate::sgt:
90 return arith::CmpIPredicate::sle;
91 case arith::CmpIPredicate::sge:
92 return arith::CmpIPredicate::slt;
93 case arith::CmpIPredicate::ult:
94 return arith::CmpIPredicate::uge;
95 case arith::CmpIPredicate::ule:
96 return arith::CmpIPredicate::ugt;
97 case arith::CmpIPredicate::ugt:
98 return arith::CmpIPredicate::ule;
99 case arith::CmpIPredicate::uge:
100 return arith::CmpIPredicate::ult;
101 }
102 llvm_unreachable("unknown cmpi predicate kind");
103}
104
105/// Equivalent to
106/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
107///
108/// Not possible to implement as chain of calls as this would introduce a
109/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
110/// on the LLVM dialect and on translation to LLVM.
111static llvm::RoundingMode
112convertArithRoundingModeToLLVMIR(std::optional<RoundingMode> roundingMode) {
113 if (!roundingMode)
115 switch (*roundingMode) {
116 case RoundingMode::downward:
117 return llvm::RoundingMode::TowardNegative;
118 case RoundingMode::to_nearest_away:
119 return llvm::RoundingMode::NearestTiesToAway;
120 case RoundingMode::to_nearest_even:
121 return llvm::RoundingMode::NearestTiesToEven;
122 case RoundingMode::toward_zero:
123 return llvm::RoundingMode::TowardZero;
124 case RoundingMode::upward:
125 return llvm::RoundingMode::TowardPositive;
126 }
127 llvm_unreachable("Unhandled rounding mode");
128}
129
130static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
131 return arith::CmpIPredicateAttr::get(pred.getContext(),
132 invertPredicate(pred.getValue()));
133}
134
136 Type elemTy = getElementTypeOrSelf(type);
137 if (elemTy.isIntOrFloat())
138 return elemTy.getIntOrFloatBitWidth();
139
140 return -1;
141}
142
144 return getScalarOrElementWidth(value.getType());
145}
146
147static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
148 APInt value;
149 if (matchPattern(attr, m_ConstantInt(&value)))
150 return value;
151
152 return failure();
153}
154
155static Attribute getBoolAttribute(Type type, bool value) {
156 auto boolAttr = BoolAttr::get(type.getContext(), value);
157 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
158 if (!shapedType)
159 return boolAttr;
160 // DenseElementsAttr requires a static shape.
161 if (!shapedType.hasStaticShape())
162 return {};
163 return DenseElementsAttr::get(shapedType, boolAttr);
164}
165
166//===----------------------------------------------------------------------===//
167// TableGen'd canonicalization patterns
168//===----------------------------------------------------------------------===//
169
170namespace {
171#include "ArithCanonicalization.inc"
172} // namespace
173
174//===----------------------------------------------------------------------===//
175// Common helpers
176//===----------------------------------------------------------------------===//
177
178/// Return the type of the same shape (scalar, vector or tensor) containing i1.
180 auto i1Type = IntegerType::get(type.getContext(), 1);
181 if (auto shapedType = dyn_cast<ShapedType>(type))
182 return shapedType.cloneWith(std::nullopt, i1Type);
183 if (llvm::isa<UnrankedTensorType>(type))
184 return UnrankedTensorType::get(i1Type);
185 return i1Type;
186}
187
188//===----------------------------------------------------------------------===//
189// ConstantOp
190//===----------------------------------------------------------------------===//
191
192void arith::ConstantOp::getAsmResultNames(
193 function_ref<void(Value, StringRef)> setNameFn) {
194 auto type = getType();
195 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
196 auto intType = dyn_cast<IntegerType>(type);
197
198 // Sugar i1 constants with 'true' and 'false'.
199 if (intType && intType.getWidth() == 1)
200 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
201
202 // Otherwise, build a complex name with the value and type.
203 SmallString<32> specialNameBuffer;
204 llvm::raw_svector_ostream specialName(specialNameBuffer);
205 specialName << 'c' << intCst.getValue();
206 if (intType)
207 specialName << '_' << type;
208 setNameFn(getResult(), specialName.str());
209 } else {
210 setNameFn(getResult(), "cst");
211 }
212}
213
214/// TODO: disallow arith.constant to return anything other than signless integer
215/// or float like.
216LogicalResult arith::ConstantOp::verify() {
217 auto type = getType();
218 // Integer values must be signless.
219 if (llvm::isa<IntegerType>(type) &&
220 !llvm::cast<IntegerType>(type).isSignless())
221 return emitOpError("integer return type must be signless");
222 // Any float or elements attribute are acceptable.
223 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
224 return emitOpError(
225 "value must be an integer, float, or elements attribute");
226 }
227
228 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
229 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
230 // However, this would most likely require updating the lowerings to LLVM.
231 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
232 return emitOpError(
233 "initializing scalable vectors with elements attribute is not supported"
234 " unless it's a vector splat");
235 return success();
236}
237
238bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
239 // The value's type must be the same as the provided type.
240 auto typedAttr = dyn_cast<TypedAttr>(value);
241 if (!typedAttr || typedAttr.getType() != type)
242 return false;
243 // Integer values must be signless.
244 if (auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(type))) {
245 if (!intType.isSignless())
246 return false;
247 }
248 // Integer, float, and element attributes are buildable.
249 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
250}
251
252ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
253 Type type, Location loc) {
254 if (isBuildableWith(value, type))
255 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
256 return nullptr;
257}
258
/// Folding a constant simply yields its value attribute.
259OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
260
262 int64_t value, unsigned width) {
263 auto type = builder.getIntegerType(width);
264 arith::ConstantOp::build(builder, result, type,
265 builder.getIntegerAttr(type, value));
266}
267
269 Location location,
271 unsigned width) {
272 mlir::OperationState state(location, getOperationName());
273 build(builder, state, value, width);
274 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
275 assert(result && "builder didn't return the right type");
276 return result;
277}
278
281 unsigned width) {
282 return create(builder, builder.getLoc(), value, width);
283}
284
286 Type type, int64_t value) {
287 arith::ConstantOp::build(builder, result, type,
288 builder.getIntegerAttr(type, value));
289}
290
292 Location location, Type type,
293 int64_t value) {
294 mlir::OperationState state(location, getOperationName());
295 build(builder, state, type, value);
296 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
297 assert(result && "builder didn't return the right type");
298 return result;
299}
300
302 Type type, int64_t value) {
303 return create(builder, builder.getLoc(), type, value);
304}
305
307 Type type, const APInt &value) {
308 arith::ConstantOp::build(builder, result, type,
309 builder.getIntegerAttr(type, value));
310}
311
313 Location location, Type type,
314 const APInt &value) {
315 mlir::OperationState state(location, getOperationName());
316 build(builder, state, type, value);
317 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
318 assert(result && "builder didn't return the right type");
319 return result;
320}
321
323 Type type,
324 const APInt &value) {
325 return create(builder, builder.getLoc(), type, value);
326}
327
329 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
330 return constOp.getType().isSignlessInteger();
331 return false;
332}
333
335 FloatType type, const APFloat &value) {
336 arith::ConstantOp::build(builder, result, type,
337 builder.getFloatAttr(type, value));
338}
339
341 Location location,
342 FloatType type,
343 const APFloat &value) {
344 mlir::OperationState state(location, getOperationName());
345 build(builder, state, type, value);
346 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
347 assert(result && "builder didn't return the right type");
348 return result;
349}
350
353 const APFloat &value) {
354 return create(builder, builder.getLoc(), type, value);
355}
356
358 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
359 return llvm::isa<FloatType>(constOp.getType());
360 return false;
361}
362
364 int64_t value) {
365 arith::ConstantOp::build(builder, result, builder.getIndexType(),
366 builder.getIndexAttr(value));
367}
368
370 Location location,
371 int64_t value) {
372 mlir::OperationState state(location, getOperationName());
373 build(builder, state, value);
374 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
375 assert(result && "builder didn't return the right type");
376 return result;
377}
378
381 return create(builder, builder.getLoc(), value);
382}
383
385 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
386 return constOp.getType().isIndex();
387 return false;
388}
389
391 Type type) {
392 // TODO: Incorporate this check to `FloatAttr::get*`.
393 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
394 "type doesn't have a zero representation");
395 TypedAttr zeroAttr = builder.getZeroAttr(type);
396 assert(zeroAttr && "unsupported type for zero attribute");
397 return arith::ConstantOp::create(builder, loc, zeroAttr);
398}
399
400//===----------------------------------------------------------------------===//
401// AddIOp
402//===----------------------------------------------------------------------===//
403
404OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
405 // addi(x, 0) -> x
406 if (matchPattern(adaptor.getRhs(), m_Zero()))
407 return getLhs();
408
409 // addi(subi(a, b), b) -> a
410 if (auto sub = getLhs().getDefiningOp<SubIOp>())
411 if (getRhs() == sub.getRhs())
412 return sub.getLhs();
413
414 // addi(b, subi(a, b)) -> a
415 if (auto sub = getRhs().getDefiningOp<SubIOp>())
416 if (getLhs() == sub.getRhs())
417 return sub.getLhs();
418
420 adaptor.getOperands(),
421 [](APInt a, const APInt &b) { return std::move(a) + b; });
422}
423
/// Register TableGen-defined canonicalizations for addi (constant
/// re-association and mul-by-negative-one rewrites).
424void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
425 MLIRContext *context) {
426 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
427 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
428}
429
430//===----------------------------------------------------------------------===//
431// AddUIExtendedOp
432//===----------------------------------------------------------------------===//
433
434std::optional<SmallVector<int64_t, 4>>
435arith::AddUIExtendedOp::getShapeForUnroll() {
436 if (auto vt = dyn_cast<VectorType>(getType(0)))
437 return llvm::to_vector<4>(vt.getShape());
438 return std::nullopt;
439}
440
441// Returns the overflow bit, assuming that `sum` is the result of unsigned
442// addition of `operand` and another number.
443static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
444 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
445}
446
447LogicalResult
448arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
449 SmallVectorImpl<OpFoldResult> &results) {
450 Type overflowTy = getOverflow().getType();
451 // addui_extended(x, 0) -> x, false
452 if (matchPattern(getRhs(), m_Zero())) {
453 Builder builder(getContext());
454 auto falseValue = builder.getZeroAttr(overflowTy);
455
456 results.push_back(getLhs());
457 results.push_back(falseValue);
458 return success();
459 }
460
461 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
462 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
463 // operands. If that succeeds, calculate the overflow bit based on the sum
464 // and the first (constant) operand, `lhs`.
465 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
466 adaptor.getOperands(),
467 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
468 // If any operand is poison, propagate poison to both results.
469 if (matchPattern(sumAttr, ub::m_Poison())) {
470 results.push_back(sumAttr);
471 results.push_back(sumAttr);
472 return success();
473 }
474 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
475 ArrayRef({sumAttr, adaptor.getLhs()}),
476 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
478 if (!overflowAttr)
479 return failure();
480
481 results.push_back(sumAttr);
482 results.push_back(overflowAttr);
483 return success();
484 }
485
486 return failure();
487}
488
/// Register the rewrite collapsing addui_extended to addi when the overflow
/// bit is unused.
489void arith::AddUIExtendedOp::getCanonicalizationPatterns(
490 RewritePatternSet &patterns, MLIRContext *context) {
491 patterns.add<AddUIExtendedToAddI>(context);
492}
493
494//===----------------------------------------------------------------------===//
495// SubIOp
496//===----------------------------------------------------------------------===//
497
498OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
499 // subi(x,x) -> 0
500 if (getOperand(0) == getOperand(1)) {
501 auto shapedType = dyn_cast<ShapedType>(getType());
502 // We can't generate a constant with a dynamic shaped tensor.
503 if (!shapedType || shapedType.hasStaticShape())
504 return Builder(getContext()).getZeroAttr(getType());
505 }
506 // subi(x,0) -> x
507 if (matchPattern(adaptor.getRhs(), m_Zero()))
508 return getLhs();
509
510 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
511 // subi(addi(a, b), b) -> a
512 if (getRhs() == add.getRhs())
513 return add.getLhs();
514 // subi(addi(a, b), a) -> b
515 if (getRhs() == add.getLhs())
516 return add.getRhs();
517 }
518
519 // subi(a, subi(a, b)) -> b
520 if (auto sub = getRhs().getDefiningOp<SubIOp>())
521 if (getLhs() == sub.getLhs())
522 return sub.getRhs();
523
525 adaptor.getOperands(),
526 [](APInt a, const APInt &b) { return std::move(a) - b; });
527}
528
/// Register TableGen-defined canonicalizations for subi (constant
/// re-association across add/sub chains).
529void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
530 MLIRContext *context) {
531 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
532 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
533 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
534}
535
536//===----------------------------------------------------------------------===//
537// MulIOp
538//===----------------------------------------------------------------------===//
539
540OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
541 // muli(x, 0) -> 0
542 if (matchPattern(adaptor.getRhs(), m_Zero()))
543 return getRhs();
544 // muli(x, 1) -> x
545 if (matchPattern(adaptor.getRhs(), m_One()))
546 return getLhs();
547 // TODO: Handle the overflow case.
548
549 // default folder
551 adaptor.getOperands(),
552 [](const APInt &a, const APInt &b) { return a * b; });
553}
554
555void arith::MulIOp::getAsmResultNames(
556 function_ref<void(Value, StringRef)> setNameFn) {
557 if (!isa<IndexType>(getType()))
558 return;
559
560 // Match vector.vscale by name to avoid depending on the vector dialect (which
561 // is a circular dependency).
562 auto isVscale = [](Operation *op) {
563 return op && op->getName().getStringRef() == "vector.vscale";
564 };
565
566 IntegerAttr baseValue;
567 auto isVscaleExpr = [&](Value a, Value b) {
568 return matchPattern(a, m_Constant(&baseValue)) &&
569 isVscale(b.getDefiningOp());
570 };
571
572 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
573 return;
574
575 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
576 SmallString<32> specialNameBuffer;
577 llvm::raw_svector_ostream specialName(specialNameBuffer);
578 specialName << 'c' << baseValue.getInt() << "_vscale";
579 setNameFn(getResult(), specialName.str());
580}
581
/// Register the rewrite merging nested constant multiplications.
582void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
583 MLIRContext *context) {
584 patterns.add<MulIMulIConstant>(context);
585}
586
587//===----------------------------------------------------------------------===//
588// MulSIExtendedOp
589//===----------------------------------------------------------------------===//
590
591std::optional<SmallVector<int64_t, 4>>
592arith::MulSIExtendedOp::getShapeForUnroll() {
593 if (auto vt = dyn_cast<VectorType>(getType(0)))
594 return llvm::to_vector<4>(vt.getShape());
595 return std::nullopt;
596}
597
598LogicalResult
599arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
600 SmallVectorImpl<OpFoldResult> &results) {
601 // mulsi_extended(x, 0) -> 0, 0
602 if (matchPattern(adaptor.getRhs(), m_Zero())) {
603 Attribute zero = adaptor.getRhs();
604 results.push_back(zero);
605 results.push_back(zero);
606 return success();
607 }
608
609 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
610 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
611 adaptor.getOperands(),
612 [](const APInt &a, const APInt &b) { return a * b; })) {
613 // Invoke the constant fold helper again to calculate the 'high' result.
614 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
615 llvm::APIntOps::mulhs);
616 assert(highAttr && "Unexpected constant-folding failure");
617
618 results.push_back(lowAttr);
619 results.push_back(highAttr);
620 return success();
621 }
622
623 return failure();
624}
625
/// Register rewrites collapsing mulsi_extended when only the low half is used
/// or the RHS is one.
626void arith::MulSIExtendedOp::getCanonicalizationPatterns(
627 RewritePatternSet &patterns, MLIRContext *context) {
628 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
629}
630
631//===----------------------------------------------------------------------===//
632// MulUIExtendedOp
633//===----------------------------------------------------------------------===//
634
635std::optional<SmallVector<int64_t, 4>>
636arith::MulUIExtendedOp::getShapeForUnroll() {
637 if (auto vt = dyn_cast<VectorType>(getType(0)))
638 return llvm::to_vector<4>(vt.getShape());
639 return std::nullopt;
640}
641
642LogicalResult
643arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
644 SmallVectorImpl<OpFoldResult> &results) {
645 // mului_extended(x, 0) -> 0, 0
646 if (matchPattern(adaptor.getRhs(), m_Zero())) {
647 Attribute zero = adaptor.getRhs();
648 results.push_back(zero);
649 results.push_back(zero);
650 return success();
651 }
652
653 // mului_extended(x, 1) -> x, 0
654 if (matchPattern(adaptor.getRhs(), m_One())) {
655 Builder builder(getContext());
656 Attribute zero = builder.getZeroAttr(getLhs().getType());
657 results.push_back(getLhs());
658 results.push_back(zero);
659 return success();
660 }
661
662 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
663 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
664 adaptor.getOperands(),
665 [](const APInt &a, const APInt &b) { return a * b; })) {
666 // Invoke the constant fold helper again to calculate the 'high' result.
667 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
668 llvm::APIntOps::mulhu);
669 assert(highAttr && "Unexpected constant-folding failure");
670
671 results.push_back(lowAttr);
672 results.push_back(highAttr);
673 return success();
674 }
675
676 return failure();
677}
678
/// Register the rewrite collapsing mului_extended to muli when only the low
/// half is used.
679void arith::MulUIExtendedOp::getCanonicalizationPatterns(
680 RewritePatternSet &patterns, MLIRContext *context) {
681 patterns.add<MulUIExtendedToMulI>(context);
682}
683
684//===----------------------------------------------------------------------===//
685// DivUIOp
686//===----------------------------------------------------------------------===//
687
688/// Fold `(a * b) / b -> a`
690 arith::IntegerOverflowFlags ovfFlags) {
691 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
692 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
693 return {};
694
695 if (mul.getLhs() == rhs)
696 return mul.getRhs();
697
698 if (mul.getRhs() == rhs)
699 return mul.getLhs();
700
701 return {};
702}
703
704OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
705 // divui (x, 1) -> x.
706 if (matchPattern(adaptor.getRhs(), m_One()))
707 return getLhs();
708
709 // (a * b) / b -> a
710 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
711 return val;
712
713 // Don't fold if it would require a division by zero.
714 bool div0 = false;
715 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
716 [&](APInt a, const APInt &b) {
717 if (div0 || !b) {
718 div0 = true;
719 return a;
720 }
721 return a.udiv(b);
722 });
723
724 return div0 ? Attribute() : result;
725}
726
727/// Returns whether an unsigned division by `divisor` is speculatable.
729 // X / 0 => UB
730 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
732
734}
735
/// divui speculation depends only on whether the divisor can be zero.
736Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
737 return getDivUISpeculatability(getRhs());
738}
739
740//===----------------------------------------------------------------------===//
741// DivSIOp
742//===----------------------------------------------------------------------===//
743
744OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
745 // divsi (x, 1) -> x.
746 if (matchPattern(adaptor.getRhs(), m_One()))
747 return getLhs();
748
749 // (a * b) / b -> a
750 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
751 return val;
752
753 // Don't fold if it would overflow or if it requires a division by zero.
754 bool overflowOrDiv0 = false;
756 adaptor.getOperands(), [&](APInt a, const APInt &b) {
757 if (overflowOrDiv0 || !b) {
758 overflowOrDiv0 = true;
759 return a;
760 }
761 return a.sdiv_ov(b, overflowOrDiv0);
762 });
763
764 return overflowOrDiv0 ? Attribute() : result;
765}
766
767/// Returns whether a signed division by `divisor` is speculatable. This
768/// function conservatively assumes that all signed division by -1 are not
769/// speculatable.
771 // X / 0 => UB
772 // INT_MIN / -1 => UB
773 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
776
778}
779
/// divsi speculation depends on the divisor excluding both 0 and -1.
780Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
781 return getDivSISpeculatability(getRhs());
782}
783
784//===----------------------------------------------------------------------===//
785// Ceil and floor division folding helpers
786//===----------------------------------------------------------------------===//
787
788static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
789 bool &overflow) {
790 // Returns (a-1)/b + 1
791 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
792 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
793 return val.sadd_ov(one, overflow);
794}
795
796//===----------------------------------------------------------------------===//
797// CeilDivUIOp
798//===----------------------------------------------------------------------===//
799
800OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
801 // ceildivui (x, 1) -> x.
802 if (matchPattern(adaptor.getRhs(), m_One()))
803 return getLhs();
804
805 bool overflowOrDiv0 = false;
807 adaptor.getOperands(), [&](APInt a, const APInt &b) {
808 if (overflowOrDiv0 || !b) {
809 overflowOrDiv0 = true;
810 return a;
811 }
812 APInt quotient = a.udiv(b);
813 if (!a.urem(b))
814 return quotient;
815 APInt one(a.getBitWidth(), 1, true);
816 return quotient.uadd_ov(one, overflowOrDiv0);
817 });
818
819 return overflowOrDiv0 ? Attribute() : result;
820}
821
/// ceildivui speculation depends only on whether the divisor can be zero.
822Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
823 return getDivUISpeculatability(getRhs());
824}
825
826//===----------------------------------------------------------------------===//
827// CeilDivSIOp
828//===----------------------------------------------------------------------===//
829
830OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
831 // ceildivsi (x, 1) -> x.
832 if (matchPattern(adaptor.getRhs(), m_One()))
833 return getLhs();
834
835 // Don't fold if it would overflow or if it requires a division by zero.
836 // TODO: This hook won't fold operations where a = MININT, because
837 // negating MININT overflows. This can be improved.
838 bool overflowOrDiv0 = false;
840 adaptor.getOperands(), [&](APInt a, const APInt &b) {
841 if (overflowOrDiv0 || !b) {
842 overflowOrDiv0 = true;
843 return a;
844 }
845 if (!a)
846 return a;
847 // After this point we know that neither a or b are zero.
848 unsigned bits = a.getBitWidth();
849 APInt zero = APInt::getZero(bits);
850 bool aGtZero = a.sgt(zero);
851 bool bGtZero = b.sgt(zero);
852 if (aGtZero && bGtZero) {
853 // Both positive, return ceil(a, b).
854 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
855 }
856
857 // No folding happens if any of the intermediate arithmetic operations
858 // overflows.
859 bool overflowNegA = false;
860 bool overflowNegB = false;
861 bool overflowDiv = false;
862 bool overflowNegRes = false;
863 if (!aGtZero && !bGtZero) {
864 // Both negative, return ceil(-a, -b).
865 APInt posA = zero.ssub_ov(a, overflowNegA);
866 APInt posB = zero.ssub_ov(b, overflowNegB);
867 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
868 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
869 return res;
870 }
871 if (!aGtZero && bGtZero) {
872 // A is negative, b is positive, return - ( -a / b).
873 APInt posA = zero.ssub_ov(a, overflowNegA);
874 APInt div = posA.sdiv_ov(b, overflowDiv);
875 APInt res = zero.ssub_ov(div, overflowNegRes);
876 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
877 return res;
878 }
879 // A is positive, b is negative, return - (a / -b).
880 APInt posB = zero.ssub_ov(b, overflowNegB);
881 APInt div = a.sdiv_ov(posB, overflowDiv);
882 APInt res = zero.ssub_ov(div, overflowNegRes);
883
884 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
885 return res;
886 });
887
888 return overflowOrDiv0 ? Attribute() : result;
889}
890
/// ceildivsi speculation depends on the divisor excluding both 0 and -1.
891Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
892 return getDivSISpeculatability(getRhs());
893}
894
895//===----------------------------------------------------------------------===//
896// FloorDivSIOp
897//===----------------------------------------------------------------------===//
898
899OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
900 // floordivsi (x, 1) -> x.
901 if (matchPattern(adaptor.getRhs(), m_One()))
902 return getLhs();
903
904 // Don't fold if it would overflow or if it requires a division by zero.
905 bool overflowOrDiv = false;
907 adaptor.getOperands(), [&](APInt a, const APInt &b) {
908 if (b.isZero()) {
909 overflowOrDiv = true;
910 return a;
911 }
912 return a.sfloordiv_ov(b, overflowOrDiv);
913 });
914
915 return overflowOrDiv ? Attribute() : result;
916}
917
918//===----------------------------------------------------------------------===//
919// RemUIOp
920//===----------------------------------------------------------------------===//
921
922OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
923 // remui (x, 1) -> 0.
924 if (matchPattern(adaptor.getRhs(), m_One()))
925 return Builder(getContext()).getZeroAttr(getType());
926
927 // Don't fold if it would require a division by zero.
928 bool div0 = false;
929 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
930 [&](APInt a, const APInt &b) {
931 if (div0 || b.isZero()) {
932 div0 = true;
933 return a;
934 }
935 return a.urem(b);
936 });
937
938 return div0 ? Attribute() : result;
939}
940
/// remui speculation depends only on whether the divisor can be zero.
941Speculation::Speculatability arith::RemUIOp::getSpeculatability() {
942 return getDivUISpeculatability(getRhs());
943}
944
945//===----------------------------------------------------------------------===//
946// RemSIOp
947//===----------------------------------------------------------------------===//
948
949OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
950 // remsi (x, 1) -> 0.
951 if (matchPattern(adaptor.getRhs(), m_One()))
952 return Builder(getContext()).getZeroAttr(getType());
953
954 // Don't fold if it would require a division by zero.
955 bool div0 = false;
956 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
957 [&](APInt a, const APInt &b) {
958 if (div0 || b.isZero()) {
959 div0 = true;
960 return a;
961 }
962 return a.srem(b);
963 });
964
965 return div0 ? Attribute() : result;
966}
967
968Speculation::Speculatability arith::RemSIOp::getSpeculatability() {
969 // X % 0 => UB
970 // X % -1 is well-defined (always 0), unlike X / -1 which can overflow.
971 if (matchPattern(getRhs(), m_IntRangeWithoutZeroS()))
973
975}
976
977//===----------------------------------------------------------------------===//
978// AndIOp
979//===----------------------------------------------------------------------===//
980
981/// Fold `and(a, and(a, b))` to `and(a, b)`
982static Value foldAndIofAndI(arith::AndIOp op) {
983 for (bool reversePrev : {false, true}) {
984 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
985 .getDefiningOp<arith::AndIOp>();
986 if (!prev)
987 continue;
988
989 Value other = (reversePrev ? op.getLhs() : op.getRhs());
990 if (other != prev.getLhs() && other != prev.getRhs())
991 continue;
992
993 return prev.getResult();
994 }
995 return {};
996}
997
998OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
999 /// and(x, 0) -> 0
1000 if (matchPattern(adaptor.getRhs(), m_Zero()))
1001 return getRhs();
1002 /// and(x, allOnes) -> x
1003 APInt intValue;
1004 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
1005 intValue.isAllOnes())
1006 return getLhs();
1007 /// and(x, not(x)) -> 0
1008 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1009 m_ConstantInt(&intValue))) &&
1010 intValue.isAllOnes())
1011 return Builder(getContext()).getZeroAttr(getType());
1012 /// and(not(x), x) -> 0
1013 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1014 m_ConstantInt(&intValue))) &&
1015 intValue.isAllOnes())
1016 return Builder(getContext()).getZeroAttr(getType());
1017
1018 /// and(a, and(a, b)) -> and(a, b)
1019 if (Value result = foldAndIofAndI(*this))
1020 return result;
1021
1023 adaptor.getOperands(),
1024 [](APInt a, const APInt &b) { return std::move(a) & b; });
1025}
1026
1027//===----------------------------------------------------------------------===//
1028// OrIOp
1029//===----------------------------------------------------------------------===//
1030
1031OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1032 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1033 /// or(x, 0) -> x
1034 if (rhsVal.isZero())
1035 return getLhs();
1036 /// or(x, <all ones>) -> <all ones>
1037 if (rhsVal.isAllOnes())
1038 return adaptor.getRhs();
1039 }
1040
1041 APInt intValue;
1042 /// or(x, xor(x, 1)) -> 1
1043 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1044 m_ConstantInt(&intValue))) &&
1045 intValue.isAllOnes())
1046 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1047 /// or(xor(x, 1), x) -> 1
1048 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1049 m_ConstantInt(&intValue))) &&
1050 intValue.isAllOnes())
1051 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1052
1054 adaptor.getOperands(),
1055 [](APInt a, const APInt &b) { return std::move(a) | b; });
1056}
1057
1058//===----------------------------------------------------------------------===//
1059// XOrIOp
1060//===----------------------------------------------------------------------===//
1061
1062OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1063 /// xor(x, 0) -> x
1064 if (matchPattern(adaptor.getRhs(), m_Zero()))
1065 return getLhs();
1066 /// xor(x, x) -> 0
1067 if (getLhs() == getRhs())
1068 return Builder(getContext()).getZeroAttr(getType());
1069 /// xor(xor(x, a), a) -> x
1070 /// xor(xor(a, x), a) -> x
1071 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1072 if (prev.getRhs() == getRhs())
1073 return prev.getLhs();
1074 if (prev.getLhs() == getRhs())
1075 return prev.getRhs();
1076 }
1077 /// xor(a, xor(x, a)) -> x
1078 /// xor(a, xor(a, x)) -> x
1079 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1080 if (prev.getRhs() == getLhs())
1081 return prev.getLhs();
1082 if (prev.getLhs() == getLhs())
1083 return prev.getRhs();
1084 }
1085
1087 adaptor.getOperands(),
1088 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1089}
1090
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register rewrite patterns declared elsewhere (presumably in the arith
  // canonicalization pattern definitions — confirm in the .td/pattern file).
  patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
}
1095
1096//===----------------------------------------------------------------------===//
1097// NegFOp
1098//===----------------------------------------------------------------------===//
1099
1100OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1101 /// negf(negf(x)) -> x
1102 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1103 return op.getOperand();
1104 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1105 [](const APFloat &a) { return -a; });
1106}
1107
1108//===----------------------------------------------------------------------===//
1109// FlushDenormalsOp
1110//===----------------------------------------------------------------------===//
1111
1112OpFoldResult arith::FlushDenormalsOp::fold(FoldAdaptor adaptor) {
1113 // TODO: Fold flush_denormals if the floating-point type does not support
1114 // denormals. There is currently no API to query this information from
1115 // APFloat.
1116
1117 // flush_denormals(flush_denormals(x)) -> flush_denormals(x)
1118 if (auto op = this->getOperand().getDefiningOp<arith::FlushDenormalsOp>())
1119 return op.getResult();
1120
1121 // Constant-fold flush_denormals if the operand is a constant.
1123 adaptor.getOperands(), [](const APFloat &a) {
1124 if (a.isDenormal())
1125 return APFloat::getZero(a.getSemantics(), a.isNegative());
1126 return a;
1127 });
1128}
1129
1130//===----------------------------------------------------------------------===//
1131// AddFOp
1132//===----------------------------------------------------------------------===//
1133
1134OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1135 // addf(x, -0) -> x
1136 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1137 return getLhs();
1138
1139 auto rm = getRoundingmode();
1141 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1142 APFloat result(a);
1143 result.add(b, convertArithRoundingModeToLLVMIR(rm));
1144 return result;
1145 });
1146}
1147
1148//===----------------------------------------------------------------------===//
1149// SubFOp
1150//===----------------------------------------------------------------------===//
1151
1152OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1153 // subf(x, +0) -> x
1154 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1155 return getLhs();
1156
1157 auto rm = getRoundingmode();
1159 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1160 APFloat result(a);
1161 result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
1162 return result;
1163 });
1164}
1165
void arith::SubFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register the subf-specific rewrite declared elsewhere in the dialect.
  patterns.add<SubFOfNegZero>(context);
}
1170
1171//===----------------------------------------------------------------------===//
1172// MaximumFOp
1173//===----------------------------------------------------------------------===//
1174
1175OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1176 // maximumf(x,x) -> x
1177 if (getLhs() == getRhs())
1178 return getRhs();
1179
1180 // maximumf(x, -inf) -> x
1181 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1182 return getLhs();
1183
1184 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maximum);
1185}
1186
1187//===----------------------------------------------------------------------===//
1188// MaxNumFOp
1189//===----------------------------------------------------------------------===//
1190
1191OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1192 // maxnumf(x,x) -> x
1193 if (getLhs() == getRhs())
1194 return getRhs();
1195
1196 // maxnumf(x, NaN) -> x
1197 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1198 return getLhs();
1199
1200 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1201}
1202
1203//===----------------------------------------------------------------------===//
1204// MaxSIOp
1205//===----------------------------------------------------------------------===//
1206
1207OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1208 // maxsi(x,x) -> x
1209 if (getLhs() == getRhs())
1210 return getRhs();
1211
1212 if (APInt intValue;
1213 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1214 // maxsi(x,MAX_INT) -> MAX_INT
1215 if (intValue.isMaxSignedValue())
1216 return getRhs();
1217 // maxsi(x, MIN_INT) -> x
1218 if (intValue.isMinSignedValue())
1219 return getLhs();
1220 }
1221
1222 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1223 llvm::APIntOps::smax);
1224}
1225
1226//===----------------------------------------------------------------------===//
1227// MaxUIOp
1228//===----------------------------------------------------------------------===//
1229
1230OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1231 // maxui(x,x) -> x
1232 if (getLhs() == getRhs())
1233 return getRhs();
1234
1235 if (APInt intValue;
1236 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1237 // maxui(x,MAX_INT) -> MAX_INT
1238 if (intValue.isMaxValue())
1239 return getRhs();
1240 // maxui(x, MIN_INT) -> x
1241 if (intValue.isMinValue())
1242 return getLhs();
1243 }
1244
1245 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1246 llvm::APIntOps::umax);
1247}
1248
1249//===----------------------------------------------------------------------===//
1250// MinimumFOp
1251//===----------------------------------------------------------------------===//
1252
1253OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1254 // minimumf(x,x) -> x
1255 if (getLhs() == getRhs())
1256 return getRhs();
1257
1258 // minimumf(x, +inf) -> x
1259 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1260 return getLhs();
1261
1262 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::minimum);
1263}
1264
1265//===----------------------------------------------------------------------===//
1266// MinNumFOp
1267//===----------------------------------------------------------------------===//
1268
1269OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1270 // minnumf(x,x) -> x
1271 if (getLhs() == getRhs())
1272 return getRhs();
1273
1274 // minnumf(x, NaN) -> x
1275 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1276 return getLhs();
1277
1278 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::minnum);
1279}
1280
1281//===----------------------------------------------------------------------===//
1282// MinSIOp
1283//===----------------------------------------------------------------------===//
1284
1285OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1286 // minsi(x,x) -> x
1287 if (getLhs() == getRhs())
1288 return getRhs();
1289
1290 if (APInt intValue;
1291 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1292 // minsi(x,MIN_INT) -> MIN_INT
1293 if (intValue.isMinSignedValue())
1294 return getRhs();
1295 // minsi(x, MAX_INT) -> x
1296 if (intValue.isMaxSignedValue())
1297 return getLhs();
1298 }
1299
1300 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1301 llvm::APIntOps::smin);
1302}
1303
1304//===----------------------------------------------------------------------===//
1305// MinUIOp
1306//===----------------------------------------------------------------------===//
1307
1308OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1309 // minui(x,x) -> x
1310 if (getLhs() == getRhs())
1311 return getRhs();
1312
1313 if (APInt intValue;
1314 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1315 // minui(x,MIN_INT) -> MIN_INT
1316 if (intValue.isMinValue())
1317 return getRhs();
1318 // minui(x, MAX_INT) -> x
1319 if (intValue.isMaxValue())
1320 return getLhs();
1321 }
1322
1323 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1324 llvm::APIntOps::umin);
1325}
1326
1327//===----------------------------------------------------------------------===//
1328// MulFOp
1329//===----------------------------------------------------------------------===//
1330
1331OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1332 // mulf(x, 1) -> x
1333 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1334 return getLhs();
1335
1336 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1337 arith::FastMathFlags::nsz)) {
1338 // mulf(x, 0) -> 0
1339 if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1340 return getRhs();
1341 }
1342
1343 auto rm = getRoundingmode();
1345 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1346 APFloat result(a);
1347 result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
1348 return result;
1349 });
1350}
1351
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register the mulf-specific rewrite declared elsewhere in the dialect.
  patterns.add<MulFOfNegF>(context);
}
1356
1357//===----------------------------------------------------------------------===//
1358// DivFOp
1359//===----------------------------------------------------------------------===//
1360
1361OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1362 // divf(x, 1) -> x
1363 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1364 return getLhs();
1365
1366 auto rm = getRoundingmode();
1368 adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
1369 APFloat result(a);
1370 result.divide(b, convertArithRoundingModeToLLVMIR(rm));
1371 return result;
1372 });
1373}
1374
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register the divf-specific rewrite declared elsewhere in the dialect.
  patterns.add<DivFOfNegF>(context);
}
1379
1380//===----------------------------------------------------------------------===//
1381// RemFOp
1382//===----------------------------------------------------------------------===//
1383
1384OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1385 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1386 [](const APFloat &a, const APFloat &b) {
1387 APFloat result(a);
1388 // APFloat::mod() offers the remainder
1389 // behavior we want, i.e. the result has
1390 // the sign of LHS operand.
1391 (void)result.mod(b);
1392 return result;
1393 });
1394}
1395
1396//===----------------------------------------------------------------------===//
1397// Utility functions for verifying cast ops
1398//===----------------------------------------------------------------------===//
1399
// Valueless "tag" parameter used below to pass a pack of types to a function;
// only the pointee's template arguments matter, never the pointer value.
template <typename... Types>
using type_list = std::tuple<Types...> *;
1402
1403/// Returns a non-null type only if the provided type is one of the allowed
1404/// types or one of the allowed shaped types of the allowed types. Returns the
1405/// element type if a valid shaped type is provided.
1406template <typename... ShapedTypes, typename... ElementTypes>
1409 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1410 return {};
1411
1412 auto underlyingType = getElementTypeOrSelf(type);
1413 if (!llvm::isa<ElementTypes...>(underlyingType))
1414 return {};
1415
1416 return underlyingType;
1417}
1418
1419/// Get allowed underlying types for vectors and tensors.
1420template <typename... ElementTypes>
1425
1426/// Get allowed underlying types for vectors, tensors, and memrefs.
1427template <typename... ElementTypes>
1433
1434/// Return false if both types are ranked tensor with mismatching encoding.
1435static bool hasSameEncoding(Type typeA, Type typeB) {
1436 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1437 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1438 if (!rankedTensorA || !rankedTensorB)
1439 return true;
1440 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1441}
1442
1444 if (inputs.size() != 1 || outputs.size() != 1)
1445 return false;
1446 if (!hasSameEncoding(inputs.front(), outputs.front()))
1447 return false;
1448 return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1449}
1450
1451//===----------------------------------------------------------------------===//
1452// Verifiers for integer and floating point extension/truncation ops
1453//===----------------------------------------------------------------------===//
1454
1455// Extend ops can only extend to a wider type.
1456template <typename ValType, typename Op>
1457static LogicalResult verifyExtOp(Op op) {
1458 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1459 Type dstType = getElementTypeOrSelf(op.getType());
1460
1461 if (llvm::cast<ValType>(srcType).getWidth() >=
1462 llvm::cast<ValType>(dstType).getWidth())
1463 return op.emitError("result type ")
1464 << dstType << " must be wider than operand type " << srcType;
1465
1466 return success();
1467}
1468
1469// Truncate ops can only truncate to a shorter type.
1470template <typename ValType, typename Op>
1471static LogicalResult verifyTruncateOp(Op op) {
1472 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1473 Type dstType = getElementTypeOrSelf(op.getType());
1474
1475 if (llvm::cast<ValType>(srcType).getWidth() <=
1476 llvm::cast<ValType>(dstType).getWidth())
1477 return op.emitError("result type ")
1478 << dstType << " must be shorter than operand type " << srcType;
1479
1480 return success();
1481}
1482
1483/// Validate a cast that changes the width of a type.
1484template <template <typename> class WidthComparator, typename... ElementTypes>
1485static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1486 if (!areValidCastInputsAndOutputs(inputs, outputs))
1487 return false;
1488
1489 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1490 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1491 if (!srcType || !dstType)
1492 return false;
1493
1494 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1495 srcType.getIntOrFloatBitWidth());
1496}
1497
1498/// Attempts to convert `sourceValue` to an APFloat value with
1499/// `targetSemantics` and `roundingMode`, without any information loss.
1500static FailureOr<APFloat>
1501convertFloatValue(APFloat sourceValue,
1502 const llvm::fltSemantics &targetSemantics,
1503 llvm::RoundingMode roundingMode = kDefaultRoundingMode) {
1504 // Reject special values that are not representable in the target type before
1505 // calling APFloat::convert, which would llvm_unreachable on them.
1506 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1507 if (sourceValue.isInfinity() &&
1508 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1509 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1510 return failure();
1511 if (sourceValue.isNaN() &&
1512 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1513 return failure();
1514
1515 bool losesInfo = false;
1516 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1517 if (losesInfo || status != APFloat::opOK)
1518 return failure();
1519
1520 return sourceValue;
1521}
1522
1523//===----------------------------------------------------------------------===//
1524// ExtUIOp
1525//===----------------------------------------------------------------------===//
1526
1527OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1528 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1529 getInMutable().assign(lhs.getIn());
1530 return getResult();
1531 }
1532
1533 Type resType = getElementTypeOrSelf(getType());
1534 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1536 adaptor.getOperands(), getType(),
1537 [bitWidth](const APInt &a, bool &castStatus) {
1538 return a.zext(bitWidth);
1539 });
1540}
1541
1542bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1544}
1545
LogicalResult arith::ExtUIOp::verify() {
  // Result integer type must be strictly wider than the operand type.
  return verifyExtOp<IntegerType>(*this);
}
1549
1550//===----------------------------------------------------------------------===//
1551// ExtSIOp
1552//===----------------------------------------------------------------------===//
1553
1554OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1555 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1556 getInMutable().assign(lhs.getIn());
1557 return getResult();
1558 }
1559
1560 Type resType = getElementTypeOrSelf(getType());
1561 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1563 adaptor.getOperands(), getType(),
1564 [bitWidth](const APInt &a, bool &castStatus) {
1565 return a.sext(bitWidth);
1566 });
1567}
1568
1569bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1571}
1572
void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  // Register the extsi-specific rewrite declared elsewhere in the dialect.
  patterns.add<ExtSIOfExtUI>(context);
}
1577
LogicalResult arith::ExtSIOp::verify() {
  // Result integer type must be strictly wider than the operand type.
  return verifyExtOp<IntegerType>(*this);
}
1581
1582//===----------------------------------------------------------------------===//
1583// ExtFOp
1584//===----------------------------------------------------------------------===//
1585
1586/// Fold extension of float constants when there is no information loss due the
1587/// difference in fp semantics.
1588OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1589 if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1590 if (truncFOp.getOperand().getType() == getType()) {
1591 arith::FastMathFlags truncFMF =
1592 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1593 bool isTruncContract =
1594 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1595 arith::FastMathFlags extFMF =
1596 getFastmath().value_or(arith::FastMathFlags::none);
1597 bool isExtContract =
1598 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1599 if (isTruncContract && isExtContract) {
1600 return truncFOp.getOperand();
1601 }
1602 }
1603 }
1604
1605 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1606 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1608 adaptor.getOperands(), getType(),
1609 [&targetSemantics](const APFloat &a, bool &castStatus) {
1610 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1611 if (failed(result)) {
1612 castStatus = false;
1613 return a;
1614 }
1615 return *result;
1616 });
1617}
1618
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Valid iff the float result is strictly wider than the operand.
  return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
}
1622
// ExtF must strictly widen; delegate to the shared extension verifier.
LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1624
1625//===----------------------------------------------------------------------===//
1626// ScalingExtFOp
1627//===----------------------------------------------------------------------===//
1628
bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  // Only the first input takes part in the widening check — presumably the
  // value operand, with the scale operand excluded. NOTE(review): this relies
  // on `inputs.front()` converting to a one-element TypeRange; confirm
  // against the op definition.
  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
}
1633
LogicalResult arith::ScalingExtFOp::verify() {
  // Result float type must be strictly wider than the operand type.
  return verifyExtOp<FloatType>(*this);
}
1637
1638//===----------------------------------------------------------------------===//
1639// TruncIOp
1640//===----------------------------------------------------------------------===//
1641
1642OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1643 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1644 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1645 Value src = getOperand().getDefiningOp()->getOperand(0);
1646 Type srcType = getElementTypeOrSelf(src.getType());
1647 Type dstType = getElementTypeOrSelf(getType());
1648 // trunci(zexti(a)) -> trunci(a)
1649 // trunci(sexti(a)) -> trunci(a)
1650 if (llvm::cast<IntegerType>(srcType).getWidth() >
1651 llvm::cast<IntegerType>(dstType).getWidth()) {
1652 setOperand(src);
1653 return getResult();
1654 }
1655
1656 // trunci(zexti(a)) -> a
1657 // trunci(sexti(a)) -> a
1658 if (srcType == dstType)
1659 return src;
1660 }
1661
1662 // trunci(trunci(a)) -> trunci(a))
1663 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1664 setOperand(getOperand().getDefiningOp()->getOperand(0));
1665 return getResult();
1666 }
1667
1668 Type resType = getElementTypeOrSelf(getType());
1669 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1671 adaptor.getOperands(), getType(),
1672 [bitWidth](const APInt &a, bool &castStatus) {
1673 return a.trunc(bitWidth);
1674 });
1675}
1676
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Valid iff the integer result is strictly narrower than the operand.
  return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
}
1680
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Register trunci-specific rewrites declared elsewhere in the dialect.
  patterns
      .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
          context);
}
1687
LogicalResult arith::TruncIOp::verify() {
  // Result integer type must be strictly narrower than the operand type.
  return verifyTruncateOp<IntegerType>(*this);
}
1691
1692//===----------------------------------------------------------------------===//
1693// TruncFOp
1694//===----------------------------------------------------------------------===//
1695
1696/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1697/// can be represented without precision loss.
1698OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1699 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1700 if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1701 Value src = extOp.getIn();
1702 auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1703 auto intermediateType =
1704 cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1705 // Check if the srcType is representable in the intermediateType.
1706 if (llvm::APFloatBase::isRepresentableBy(
1707 srcType.getFloatSemantics(),
1708 intermediateType.getFloatSemantics())) {
1709 // truncf(extf(a)) -> truncf(a)
1710 if (srcType.getWidth() > resElemType.getWidth()) {
1711 setOperand(src);
1712 return getResult();
1713 }
1714
1715 // truncf(extf(a)) -> a
1716 if (srcType == resElemType)
1717 return src;
1718 }
1719 }
1720
1721 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1723 adaptor.getOperands(), getType(),
1724 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1725 llvm::RoundingMode llvmRoundingMode =
1726 convertArithRoundingModeToLLVMIR(getRoundingmode());
1727 FailureOr<APFloat> result =
1728 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1729 if (failed(result)) {
1730 castStatus = false;
1731 return a;
1732 }
1733 return *result;
1734 });
1735}
1736
void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Register truncf-specific rewrites declared elsewhere in the dialect.
  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
}
1741
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Valid iff the float result is strictly narrower than the operand.
  return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
}
1745
LogicalResult arith::TruncFOp::verify() {
  // Result float type must be strictly narrower than the operand type.
  return verifyTruncateOp<FloatType>(*this);
}
1749
1750//===----------------------------------------------------------------------===//
1751// ConvertFOp
1752//===----------------------------------------------------------------------===//
1753
1754OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1755 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1756 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1758 adaptor.getOperands(), getType(),
1759 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1760 llvm::RoundingMode llvmRoundingMode =
1761 convertArithRoundingModeToLLVMIR(getRoundingmode());
1762 FailureOr<APFloat> result =
1763 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1764 if (failed(result)) {
1765 castStatus = false;
1766 return a;
1767 }
1768 return *result;
1769 });
1770}
1771
1772bool arith::ConvertFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1773 if (!areValidCastInputsAndOutputs(inputs, outputs))
1774 return false;
1775 auto srcType = getTypeIfLike<FloatType>(inputs.front());
1776 auto dstType = getTypeIfLike<FloatType>(outputs.front());
1777 if (!srcType || !dstType)
1778 return false;
1779 return srcType != dstType &&
1780 srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1781}
1782
1783LogicalResult arith::ConvertFOp::verify() {
1784 auto srcType = cast<FloatType>(getElementTypeOrSelf(getIn().getType()));
1785 auto dstType = cast<FloatType>(getElementTypeOrSelf(getType()));
1786 if (srcType == dstType)
1787 return emitError("result element type ")
1788 << dstType << " must be different from operand element type "
1789 << srcType;
1790 if (srcType.getWidth() != dstType.getWidth())
1791 return emitError("result element type ")
1792 << dstType << " must have the same bitwidth as operand element type "
1793 << srcType;
1794 return success();
1795}
1796
1797//===----------------------------------------------------------------------===//
1798// ScalingTruncFOp
1799//===----------------------------------------------------------------------===//
1800
bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
                                               TypeRange outputs) {
  // Only the first input takes part in the narrowing check — presumably the
  // value operand, with the scale operand excluded. NOTE(review): this relies
  // on `inputs.front()` converting to a one-element TypeRange; confirm
  // against the op definition.
  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
}
1805
LogicalResult arith::ScalingTruncFOp::verify() {
  // Result float type must be strictly narrower than the operand type.
  return verifyTruncateOp<FloatType>(*this);
}
1809
1810//===----------------------------------------------------------------------===//
1811// AndIOp
1812//===----------------------------------------------------------------------===//
1813
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Register andi-specific rewrites declared elsewhere in the dialect.
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}
1818
1819//===----------------------------------------------------------------------===//
1820// OrIOp
1821//===----------------------------------------------------------------------===//
1822
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // Register ori-specific rewrites declared elsewhere in the dialect.
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}
1827
1828//===----------------------------------------------------------------------===//
1829// Verifiers for casts between integers and floats.
1830//===----------------------------------------------------------------------===//
1831
1832template <typename From, typename To>
1833static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1834 if (!areValidCastInputsAndOutputs(inputs, outputs))
1835 return false;
1836
1837 auto srcType = getTypeIfLike<From>(inputs.front());
1838 auto dstType = getTypeIfLike<To>(outputs.back());
1839
1840 return srcType && dstType;
1841}
1842
1843//===----------------------------------------------------------------------===//
1844// UIToFPOp
1845//===----------------------------------------------------------------------===//
1846
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Valid when casting an integer(-shaped) type to a float(-shaped) one.
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1850
1851OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1852 Type resEleType = getElementTypeOrSelf(getType());
1854 adaptor.getOperands(), getType(),
1855 [&resEleType](const APInt &a, bool &castStatus) {
1856 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1857 APFloat apf(floatTy.getFloatSemantics(),
1858 APInt::getZero(floatTy.getWidth()));
1859 apf.convertFromAPInt(a, /*IsSigned=*/false,
1860 APFloat::rmNearestTiesToEven);
1861 return apf;
1862 });
1863}
1864
void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Register the uitofp-specific rewrite declared elsewhere in the dialect.
  patterns.add<UIToFPOfExtUI>(context);
}
1869
1870//===----------------------------------------------------------------------===//
1871// SIToFPOp
1872//===----------------------------------------------------------------------===//
1873
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // Valid when casting an integer(-shaped) type to a float(-shaped) one.
  return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
}
1877
1878OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1879 Type resEleType = getElementTypeOrSelf(getType());
1881 adaptor.getOperands(), getType(),
1882 [&resEleType](const APInt &a, bool &castStatus) {
1883 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1884 APFloat apf(floatTy.getFloatSemantics(),
1885 APInt::getZero(floatTy.getWidth()));
1886 apf.convertFromAPInt(a, /*IsSigned=*/true,
1887 APFloat::rmNearestTiesToEven);
1888 return apf;
1889 });
1890}
1891
void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  // Register sitofp-specific rewrites declared elsewhere in the dialect.
  patterns.add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
}
1896
1897//===----------------------------------------------------------------------===//
1898// FPToUIOp
1899//===----------------------------------------------------------------------===//
1900
/// An fptoui cast is compatible when the source is a float-like type and the
/// destination is an integer-like type (delegated to checkIntFloatCast).
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1904
1905OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1906 Type resType = getElementTypeOrSelf(getType());
1907 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1909 adaptor.getOperands(), getType(),
1910 [&bitWidth](const APFloat &a, bool &castStatus) {
1911 bool ignored;
1912 APSInt api(bitWidth, /*isUnsigned=*/true);
1913 castStatus = APFloat::opInvalidOp !=
1914 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1915 return api;
1916 });
1917}
1918
1919//===----------------------------------------------------------------------===//
1920// FPToSIOp
1921//===----------------------------------------------------------------------===//
1922
/// An fptosi cast is compatible when the source is a float-like type and the
/// destination is an integer-like type (delegated to checkIntFloatCast).
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
}
1926
1927OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1928 Type resType = getElementTypeOrSelf(getType());
1929 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1931 adaptor.getOperands(), getType(),
1932 [&bitWidth](const APFloat &a, bool &castStatus) {
1933 bool ignored;
1934 APSInt api(bitWidth, /*isUnsigned=*/false);
1935 castStatus = APFloat::opInvalidOp !=
1936 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1937 return api;
1938 });
1939}
1940
1941//===----------------------------------------------------------------------===//
1942// IndexCastOp
1943//===----------------------------------------------------------------------===//
1944
1945/// Return the bit-width of \p t for the purpose of index_cast width checks.
1946/// For vector types use the element type; index maps to its internal storage
1947/// width (64 on all current targets).
1948static unsigned getIndexCastWidth(Type t) {
1949 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(t)))
1950 return intTy.getWidth();
1951 return IndexType::kInternalStorageBitWidth;
1952}
1953
1954static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1955 if (!areValidCastInputsAndOutputs(inputs, outputs))
1956 return false;
1957
1958 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1959 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1960 if (!srcType || !dstType)
1961 return false;
1962
1963 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1964 (srcType.isSignlessInteger() && dstType.isIndex());
1965}
1966
/// Compatibility is shared with index_castui (see areIndexCastCompatible).
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}
1971
1972OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1973 // index_cast(constant) -> constant
1974 unsigned resultBitwidth = 64; // Default for index integer attributes.
1975 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1976 resultBitwidth = intTy.getWidth();
1977
1978 if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
1979 adaptor.getOperands(), getType(),
1980 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1981 return a.sextOrTrunc(resultBitwidth);
1982 }))
1983 return foldResult;
1984
1985 // index_cast(index_cast(x : A) : B) : A -> x, but only when B is at least
1986 // as wide as A. If B is narrower, the inner cast truncates and the outer
1987 // cast sign-extends, so the round-trip is lossy.
1988 if (auto inner = getOperand().getDefiningOp<arith::IndexCastOp>()) {
1989 Value x = inner.getOperand();
1990 if (x.getType() == getType()) {
1991 if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
1992 return x;
1993 }
1994 }
1995 return {};
1996}
1997
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // index_cast(extsi(x)) is handled by IndexCastOfExtSI.
  patterns.add<IndexCastOfExtSI>(context);
}
2002
2003//===----------------------------------------------------------------------===//
2004// IndexCastUIOp
2005//===----------------------------------------------------------------------===//
2006
/// Compatibility is shared with index_cast (see areIndexCastCompatible).
bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  return areIndexCastCompatible(inputs, outputs);
}
2011
2012OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
2013 // index_castui(constant) -> constant
2014 unsigned resultBitwidth = 64; // Default for index integer attributes.
2015 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
2016 resultBitwidth = intTy.getWidth();
2017
2018 if (auto foldResult = constFoldCastOp<IntegerAttr, IntegerAttr>(
2019 adaptor.getOperands(), getType(),
2020 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
2021 return a.zextOrTrunc(resultBitwidth);
2022 }))
2023 return foldResult;
2024
2025 // index_castui(index_castui(x : A) : B) : A -> x, but only when B is at
2026 // least as wide as A. If B is narrower, the inner cast truncates and the
2027 // outer cast zero-extends, so the round-trip is lossy.
2028 if (auto inner = getOperand().getDefiningOp<arith::IndexCastUIOp>()) {
2029 Value x = inner.getOperand();
2030 if (x.getType() == getType()) {
2031 if (getIndexCastWidth(inner.getType()) >= getIndexCastWidth(x.getType()))
2032 return x;
2033 }
2034 }
2035 return {};
2036}
2037
void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // index_castui(extui(x)) is handled by IndexCastUIOfExtUI.
  patterns.add<IndexCastUIOfExtUI>(context);
}
2042
2043//===----------------------------------------------------------------------===//
2044// BitcastOp
2045//===----------------------------------------------------------------------===//
2046
2047bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
2048 if (!areValidCastInputsAndOutputs(inputs, outputs))
2049 return false;
2050
2051 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
2052 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
2053 if (!srcType || !dstType)
2054 return false;
2055
2056 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
2057}
2058
/// Fold `arith.bitcast` of constants: dense element attributes are bitcast
/// elementwise, poison propagates, and scalar int/float constants are
/// reinterpreted through their APInt bit pattern.
OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  auto resType = getType();
  auto operand = adaptor.getIn();
  if (!operand)
    return {};

  /// Bitcast dense elements.
  if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
  /// Other shaped types unhandled.
  if (llvm::isa<ShapedType>(resType))
    return {};

  /// Bitcast poison.
  if (matchPattern(operand, ub::m_Poison()))
    return ub::PoisonAttr::get(getContext());

  /// Bitcast integer or float to integer or float.
  // Extract the raw bits: floats via their IEEE bit pattern, ints directly.
  APInt bits = llvm::isa<FloatAttr>(operand)
                   ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
                   : llvm::cast<IntegerAttr>(operand).getValue();
  assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
         "trying to fold on broken IR: operands have incompatible types");

  // Reassemble the bits as a value of the result type.
  if (auto resFloatType = dyn_cast<FloatType>(resType))
    return FloatAttr::get(resType,
                          APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(resType, bits);
}
2088
void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  // bitcast(bitcast(x)) collapses via BitcastOfBitcast.
  patterns.add<BitcastOfBitcast>(context);
}
2093
2094//===----------------------------------------------------------------------===//
2095// CmpIOp
2096//===----------------------------------------------------------------------===//
2097
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  // Exhaustive switch with no default: the compiler warns if a predicate is
  // ever added without a corresponding case here.
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
2126
/// Returns true if the predicate is true for two equal operands.
/// Used to fold `cmpi(pred, x, x)`: reflexive predicates (eq, sle, sge, ule,
/// uge) yield true; irreflexive ones (ne, slt, sgt, ult, ugt) yield false.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
  switch (predicate) {
  case arith::CmpIPredicate::eq:
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::uge:
    return true;
  case arith::CmpIPredicate::ne:
  case arith::CmpIPredicate::slt:
  case arith::CmpIPredicate::sgt:
  case arith::CmpIPredicate::ult:
  case arith::CmpIPredicate::ugt:
    return false;
  }
  llvm_unreachable("unknown cmpi predicate kind");
}
2145
2146static std::optional<int64_t> getIntegerWidth(Type t) {
2147 if (auto intType = dyn_cast<IntegerType>(t)) {
2148 return intType.getWidth();
2149 }
2150 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
2151 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2152 }
2153 return std::nullopt;
2154}
2155
2156OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2157 // cmpi(pred, x, x)
2158 if (getLhs() == getRhs()) {
2159 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2160 return getBoolAttribute(getType(), val);
2161 }
2162
2163 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2164 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2165 // extsi(%x : i1 -> iN) != 0 -> %x
2166 std::optional<int64_t> integerWidth =
2167 getIntegerWidth(extOp.getOperand().getType());
2168 if (integerWidth && integerWidth.value() == 1 &&
2169 getPredicate() == arith::CmpIPredicate::ne)
2170 return extOp.getOperand();
2171 }
2172 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2173 // extui(%x : i1 -> iN) != 0 -> %x
2174 std::optional<int64_t> integerWidth =
2175 getIntegerWidth(extOp.getOperand().getType());
2176 if (integerWidth && integerWidth.value() == 1 &&
2177 getPredicate() == arith::CmpIPredicate::ne)
2178 return extOp.getOperand();
2179 }
2180
2181 // arith.cmpi ne, %val, %zero : i1 -> %val
2182 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2183 getPredicate() == arith::CmpIPredicate::ne)
2184 return getLhs();
2185 }
2186
2187 if (matchPattern(adaptor.getRhs(), m_One())) {
2188 // arith.cmpi eq, %val, %one : i1 -> %val
2189 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2190 getPredicate() == arith::CmpIPredicate::eq)
2191 return getLhs();
2192 }
2193
2194 // Move constant to the right side.
2195 if (adaptor.getLhs() && !adaptor.getRhs()) {
2196 // Do not use invertPredicate, as it will change eq to ne and vice versa.
2197 using Pred = CmpIPredicate;
2198 const std::pair<Pred, Pred> invPreds[] = {
2199 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2200 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2201 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2202 {Pred::ne, Pred::ne},
2203 };
2204 Pred origPred = getPredicate();
2205 for (auto pred : invPreds) {
2206 if (origPred == pred.first) {
2207 setPredicate(pred.second);
2208 Value lhs = getLhs();
2209 Value rhs = getRhs();
2210 getLhsMutable().assign(rhs);
2211 getRhsMutable().assign(lhs);
2212 return getResult();
2213 }
2214 }
2215 llvm_unreachable("unknown cmpi predicate kind");
2216 }
2217
2218 // We are moving constants to the right side; So if lhs is constant rhs is
2219 // guaranteed to be a constant.
2220 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2222 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2223 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2224 return APInt(1,
2225 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2226 });
2227 }
2228
2229 return {};
2230}
2231
2232void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2233 MLIRContext *context) {
2234 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2235}
2236
2237//===----------------------------------------------------------------------===//
2238// CmpFOp
2239//===----------------------------------------------------------------------===//
2240
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates. Ordered (O*) predicates are false when either
/// operand is NaN (cmpUnordered); unordered (U*) predicates are true in that
/// case.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  llvm_unreachable("unknown cmpf predicate kind");
}
2288
2289OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2290 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2291 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2292
2293 // If one operand is NaN, making them both NaN does not change the result.
2294 if (lhs && lhs.getValue().isNaN())
2295 rhs = lhs;
2296 if (rhs && rhs.getValue().isNaN())
2297 lhs = rhs;
2298
2299 if (!lhs || !rhs)
2300 return {};
2301
2302 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2303 return BoolAttr::get(getContext(), val);
2304}
2305
/// Rewrite `cmpf(sitofp(x), C)` / `cmpf(uitofp(x), C)` into an equivalent
/// integer comparison `cmpi(x, C')` when the conversion and the constant C
/// make the rewrite lossless, carefully handling range overflow and
/// fractional constants.
class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
public:
  using Base::Base;

  /// Map a float predicate to the matching integer predicate; unsigned
  /// integer forms are used when the operand came from `uitofp`.
  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                                 bool isUnsigned) {
    using namespace arith;
    switch (pred) {
    case CmpFPredicate::UEQ:
    case CmpFPredicate::OEQ:
      return CmpIPredicate::eq;
    case CmpFPredicate::UGT:
    case CmpFPredicate::OGT:
      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
    case CmpFPredicate::UGE:
    case CmpFPredicate::OGE:
      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
    case CmpFPredicate::ULT:
    case CmpFPredicate::OLT:
      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
    case CmpFPredicate::ULE:
    case CmpFPredicate::OLE:
      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
    case CmpFPredicate::UNE:
    case CmpFPredicate::ONE:
      return CmpIPredicate::ne;
    default:
      llvm_unreachable("Unexpected predicate!");
    }
  }

  LogicalResult matchAndRewrite(CmpFOp op,
                                PatternRewriter &rewriter) const override {
    FloatAttr flt;
    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
      return failure();

    const APFloat &rhs = flt.getValue();

    // Don't attempt to fold a nan.
    if (rhs.isNaN())
      return failure();

    // Get the width of the mantissa. We don't want to hack on conversions that
    // might lose information from the integer, e.g. "i64 -> float"
    FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)
      return failure();

    bool isUnsigned;
    Value intVal;

    // Only int-to-float conversions of the LHS are handled.
    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      isUnsigned = false;
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      isUnsigned = true;
      intVal = ui.getIn();
    } else {
      return failure();
    }

    // Check to see that the input is converted from an integer type that is
    // small enough that preserves all bits.
    auto intTy = llvm::cast<IntegerType>(intVal.getType());
    auto intWidth = intTy.getWidth();

    // Number of bits representing values, as opposed to the sign
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    // Following test does NOT adjust intWidth downwards for signed inputs,
    // because the most negative value still requires all the mantissa bits
    // to distinguish it from one less than that value.
    if ((int)intWidth > mantissaWidth) {
      // Conversion would lose accuracy. Check if loss can impact comparison.
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
        if (maxExponent < (int)valueBits) {
          // Conversion could create infinity.
          return failure();
        }
      } else {
        // Note that if rhs is zero or NaN, then Exp is negative
        // and first condition is trivially false.
        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
          // Conversion could affect comparison.
          return failure();
        }
      }
    }

    // Convert to equivalent cmpi predicate
    CmpIPredicate pred;
    switch (op.getPredicate()) {
    case CmpFPredicate::ORD:
      // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                 /*width=*/1);
      return success();
    case CmpFPredicate::UNO:
      // Int to fp conversion doesn't create a nan (uno checks either is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                 /*width=*/1);
      return success();
    default:
      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
      break;
    }

    if (!isUnsigned) {
      // If the rhs value is > SignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat signedMax(rhs.getSemantics());
      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMax < rhs) { // smax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
            pred == CmpIPredicate::sle)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // If the rhs value is > UnsignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat unsignedMax(rhs.getSemantics());
      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMax < rhs) { // umax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
            pred == CmpIPredicate::ule)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    }

    if (!isUnsigned) {
      // See if the rhs value is < SignedMin.
      APFloat signedMin(rhs.getSemantics());
      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMin > rhs) { // smin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
            pred == CmpIPredicate::sge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // See if the rhs value is < UnsignedMin.
      APFloat unsignedMin(rhs.getSemantics());
      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMin > rhs) { // umin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
            pred == CmpIPredicate::uge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    }

    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
    // [0, UMAX], but it may still be fractional. See if it is fractional by
    // casting the FP value to the integer value and back, checking for
    // equality. Don't do this for zero, because -0.0 is not fractional.
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
      // Undefined behavior invoked - the destination type can't represent
      // the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        // If we had a comparison against a fractional value, we have to adjust
        // the compare predicate and sometimes the value. rhsInt is rounded
        // towards zero at this point.
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4 --> true
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::eq: // (float)int == 4.4 --> false
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::ule:
          // (float)int <= 4.4 --> int <= 4
          // (float)int <= -4.4 --> false
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sle:
          // (float)int <= 4.4 --> int <= 4
          // (float)int <= -4.4 --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4 --> false
          // (float)int < 4.4 --> int <= 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < -4.4 --> int < -4
          // (float)int < 4.4 --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > 4.4 --> int > 4
          // (float)int > -4.4 --> true
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sgt:
          // (float)int > 4.4 --> int > 4
          // (float)int > -4.4 --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4 --> true
          // (float)int >= 4.4 --> int > 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= -4.4 --> int >= -4
          // (float)int >= 4.4 --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        }
      }
    }

    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
                           rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }
};
2590
2591void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2592 MLIRContext *context) {
2593 patterns.insert<CmpFIntToFPConst>(context);
2594}
2595
2596//===----------------------------------------------------------------------===//
2597// SelectOp
2598//===----------------------------------------------------------------------===//
2599
// select %arg, %c1, %c0 => extui %arg
/// Rewrites an integer select whose branches are the constants 1 and 0 into a
/// zero-extension of the i1 condition (inverting the condition with an xor
/// when the branches are 0 and 1).
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(arith::SelectOp op,
                                PatternRewriter &rewriter) const override {
    // Cannot extui i1 to i1, or i1 to f32
    if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
      return failure();

    // select %x, c1, %c0 => extui %arg
    if (matchPattern(op.getTrueValue(), m_One()) &&
        matchPattern(op.getFalseValue(), m_Zero())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
                                                  op.getCondition());
      return success();
    }

    // select %x, c0, %c1 => extui (xor %arg, true)
    if (matchPattern(op.getTrueValue(), m_Zero()) &&
        matchPattern(op.getFalseValue(), m_One())) {
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
          op, op.getType(),
          arith::XOrIOp::create(
              rewriter, op.getLoc(), op.getCondition(),
              arith::ConstantIntOp::create(rewriter, op.getLoc(),
                                           op.getCondition().getType(), 1)));
      return success();
    }

    return failure();
  }
};
2633
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // Simplify redundant/negated conditions and i1 selects, plus the
  // select-of-constants-to-extui rewrite defined above.
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
              SelectI1ToNot, SelectToExtUI>(context);
}
2639
/// Fold `arith.select`: identical branches, constant/poison conditions and
/// operands, the i1 identity select, selects fed by an eq/ne comparison of
/// the two branches, and elementwise folding over a dense constant condition.
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(adaptor.getCondition(), m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(adaptor.getCondition(), m_Zero()))
    return falseVal;

  // If either operand is fully poisoned, return the other.
  if (matchPattern(adaptor.getTrueValue(), ub::m_Poison()))
    return falseVal;

  if (matchPattern(adaptor.getFalseValue(), ub::m_Poison()))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isSignlessInteger(1) &&
      matchPattern(adaptor.getTrueValue(), m_One()) &&
      matchPattern(adaptor.getFalseValue(), m_Zero()))
    return condition;

  if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
    auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      auto cmpLhs = cmp.getLhs();
      auto cmpRhs = cmp.getRhs();

      // %0 = arith.cmpi eq, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg1

      // %0 = arith.cmpi ne, %arg0, %arg1
      // %1 = arith.select %0, %arg0, %arg1 => %arg0

      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
          (cmpRhs == trueVal && cmpLhs == falseVal))
        return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
    }
  }

  // Constant-fold constant operands over non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  if (auto cond =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
    // DenseElementsAttr by construction always has a static shape.
    assert(cond.getType().hasStaticShape() &&
           "DenseElementsAttr must have static shape");
    if (auto lhs =
            dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
      if (auto rhs =
              dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
        SmallVector<Attribute> results;
        results.reserve(static_cast<size_t>(cond.getNumElements()));
        auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
                                         cond.value_end<BoolAttr>());
        auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
                                        lhs.value_end<Attribute>());
        auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
                                        rhs.value_end<Attribute>());

        // Pick each element from the matching branch attribute.
        for (auto [condVal, lhsVal, rhsVal] :
             llvm::zip_equal(condVals, lhsVals, rhsVals))
          results.push_back(condVal.getValue() ? lhsVal : rhsVal);

        return DenseElementsAttr::get(lhs.getType(), results);
      }
    }
  }

  return nullptr;
}
2718
/// Parse `select cond, true, false [attrs] : [condType,] resultType`.
/// The optional leading type is the condition's shaped mask type; when absent
/// the condition defaults to scalar i1.
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  Type conditionType, resultType;
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  if (succeeded(parser.parseOptionalComma())) {
    // The type parsed first was actually the condition's type.
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  } else {
    conditionType = parser.getBuilder().getI1Type();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}
2741
/// Print the select op; a shaped (mask) condition type is printed explicitly
/// before the result type, mirroring the parser above.
void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";
  if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
    p << condType << ", ";
  p << getType();
}
2750
/// Verify that the condition is either a scalar signless i1 or, for shaped
/// results, an i1 mask of the same shape as the result.
LogicalResult arith::SelectOp::verify() {
  Type conditionType = getCondition().getType();
  if (conditionType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the type can be a mask with the
  // same elements.
  Type resultType = getType();
  if (!llvm::isa<TensorType, VectorType>(resultType))
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << conditionType;
  Type shapedConditionType = getI1SameShape(resultType);
  if (conditionType != shapedConditionType) {
    return emitOpError() << "expected condition type to have the same shape "
                            "as the result type, expected "
                         << shapedConditionType << ", but got "
                         << conditionType;
  }
  return success();
}
2771//===----------------------------------------------------------------------===//
2772// ShLIOp
2773//===----------------------------------------------------------------------===//
2774
2775OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2776 // shli(x, 0) -> x
2777 if (matchPattern(adaptor.getRhs(), m_Zero()))
2778 return getLhs();
2779 // Don't fold if shifting more or equal than the bit width.
2780 bool bounded = false;
2782 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2783 bounded = b.ult(b.getBitWidth());
2784 return a.shl(b);
2785 });
2786 return bounded ? result : Attribute();
2787}
2788
2789//===----------------------------------------------------------------------===//
2790// ShRUIOp
2791//===----------------------------------------------------------------------===//
2792
2793OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2794 // shrui(x, 0) -> x
2795 if (matchPattern(adaptor.getRhs(), m_Zero()))
2796 return getLhs();
2797 // Don't fold if shifting more or equal than the bit width.
2798 bool bounded = false;
2800 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2801 bounded = b.ult(b.getBitWidth());
2802 return a.lshr(b);
2803 });
2804 return bounded ? result : Attribute();
2805}
2806
2807//===----------------------------------------------------------------------===//
2808// ShRSIOp
2809//===----------------------------------------------------------------------===//
2810
2811OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2812 // shrsi(x, 0) -> x
2813 if (matchPattern(adaptor.getRhs(), m_Zero()))
2814 return getLhs();
2815 // Don't fold if shifting more or equal than the bit width.
2816 bool bounded = false;
2818 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2819 bounded = b.ult(b.getBitWidth());
2820 return a.ashr(b);
2821 });
2822 return bounded ? result : Attribute();
2823}
2824
2825//===----------------------------------------------------------------------===//
2826// Atomic Enum
2827//===----------------------------------------------------------------------===//
2828
/// Returns the identity value attribute associated with an AtomicRMWKind op,
/// i.e. the value `e` such that `op(e, x) == x` for all `x` of `resultType`.
/// For floating-point min/max kinds, `useOnlyFiniteValue` selects the largest
/// finite value instead of infinity (valid under `ninf`-style fast-math).
/// Returns null and emits an error at `loc` for unsupported kinds.
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                            OpBuilder &builder, Location loc,
                                            bool useOnlyFiniteValue) {
  switch (kind) {
  case AtomicRMWKind::maximumf: {
    // Identity for maximumf: -inf (or the most negative finite value when
    // infinities are excluded).
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/true)
                           : APFloat::getInf(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::maxnumf: {
    // maxnum-style semantics ignore NaN operands, so NaN is the identity.
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
    return builder.getFloatAttr(resultType, identity);
  }
  // Kinds whose identity is zero (additive identity; 0 for unsigned max,
  // or, and xor).
  case AtomicRMWKind::addf:
  case AtomicRMWKind::addi:
  case AtomicRMWKind::maxu:
  case AtomicRMWKind::ori:
  case AtomicRMWKind::xori:
    return builder.getZeroAttr(resultType);
  case AtomicRMWKind::andi:
    // Identity for bitwise-and: all bits set.
    return builder.getIntegerAttr(
        resultType,
        APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::maxs:
    // Identity for signed max: the minimum signed value of the width.
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMinValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minimumf: {
    // Identity for minimumf: +inf (or the largest finite value when
    // infinities are excluded).
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = useOnlyFiniteValue
                           ? APFloat::getLargest(semantic, /*Negative=*/false)
                           : APFloat::getInf(semantic, /*Negative=*/false);

    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::minnumf: {
    // minnum-style semantics ignore NaN operands, so NaN is the identity.
    const llvm::fltSemantics &semantic =
        llvm::cast<FloatType>(resultType).getFloatSemantics();
    APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
    return builder.getFloatAttr(resultType, identity);
  }
  case AtomicRMWKind::mins:
    // Identity for signed min: the maximum signed value of the width.
    return builder.getIntegerAttr(
        resultType, APInt::getSignedMaxValue(
                        llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::minu:
    // Identity for unsigned min: the maximum unsigned value of the width.
    return builder.getIntegerAttr(
        resultType,
        APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
  case AtomicRMWKind::muli:
    // Multiplicative identity.
    return builder.getIntegerAttr(resultType, 1);
  case AtomicRMWKind::mulf:
    return builder.getFloatAttr(resultType, 1);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
2896
2897/// Returns the identity numeric value of the given op.
2898std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2899 std::optional<AtomicRMWKind> maybeKind =
2901 // Floating-point operations.
2902 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2903 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2904 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2905 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2906 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2907 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2908 // Integer operations.
2909 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2910 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2911 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2912 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2913 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2914 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2915 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2916 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2917 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2918 .Default(std::nullopt);
2919 if (!maybeKind) {
2920 return std::nullopt;
2921 }
2922
2923 bool useOnlyFiniteValue = false;
2924 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2925 if (fmfOpInterface) {
2926 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2927 useOnlyFiniteValue =
2928 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2929 }
2930
2931 // Builder only used as helper for attribute creation.
2932 OpBuilder b(op->getContext());
2933 Type resultType = op->getResult(0).getType();
2934
2935 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2936 useOnlyFiniteValue);
2937}
2938
2939/// Returns the identity value associated with an AtomicRMWKind op.
2940Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2941 OpBuilder &builder, Location loc,
2942 bool useOnlyFiniteValue) {
2943 if (auto attr = getIdentityValueAttr(op, resultType, builder, loc,
2944 useOnlyFiniteValue))
2945 return arith::ConstantOp::create(builder, loc, attr);
2946 return {};
2947}
2948
2949/// Return the value obtained by applying the reduction operation kind
2950/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2952 Location loc, Value lhs, Value rhs) {
2953 switch (op) {
2954 case AtomicRMWKind::addf:
2955 return arith::AddFOp::create(builder, loc, lhs, rhs);
2956 case AtomicRMWKind::addi:
2957 return arith::AddIOp::create(builder, loc, lhs, rhs);
2958 case AtomicRMWKind::mulf:
2959 return arith::MulFOp::create(builder, loc, lhs, rhs);
2960 case AtomicRMWKind::muli:
2961 return arith::MulIOp::create(builder, loc, lhs, rhs);
2962 case AtomicRMWKind::maximumf:
2963 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2964 case AtomicRMWKind::minimumf:
2965 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2966 case AtomicRMWKind::maxnumf:
2967 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2968 case AtomicRMWKind::minnumf:
2969 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2970 case AtomicRMWKind::maxs:
2971 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2972 case AtomicRMWKind::mins:
2973 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2974 case AtomicRMWKind::maxu:
2975 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2976 case AtomicRMWKind::minu:
2977 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2978 case AtomicRMWKind::ori:
2979 return arith::OrIOp::create(builder, loc, lhs, rhs);
2980 case AtomicRMWKind::andi:
2981 return arith::AndIOp::create(builder, loc, lhs, rhs);
2982 case AtomicRMWKind::xori:
2983 return arith::XOrIOp::create(builder, loc, lhs, rhs);
2984 // TODO: Add remaining reduction operations.
2985 default:
2986 (void)emitOptionalError(loc, "Reduction operation type not supported");
2987 break;
2988 }
2989 return nullptr;
2990}
2991
2992//===----------------------------------------------------------------------===//
2993// TableGen'd op method definitions
2994//===----------------------------------------------------------------------===//
2995
2996#define GET_OP_CLASSES
2997#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2998
2999//===----------------------------------------------------------------------===//
3000// TableGen'd enum attribute definitions
3001//===----------------------------------------------------------------------===//
3002
3003#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:728
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:65
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:72
static constexpr llvm::RoundingMode kDefaultRoundingMode
Default rounding mode according to default LLVM floating-point environment.
Definition ArithOps.cpp:38
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=kDefaultRoundingMode)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:689
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(std::optional< RoundingMode > roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:112
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:788
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:770
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:55
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:155
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:60
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:135
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition ArithOps.cpp:982
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:179
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:46
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:443
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:147
static unsigned getIndexCastWidth(Type t)
Return the bit-width of t for the purpose of index_cast width checks.
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:340
static bool classof(Operation *op)
Definition ArithOps.cpp:357
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:334
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:363
static bool classof(Operation *op)
Definition ArithOps.cpp:384
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:268
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:261
static bool classof(Operation *op)
Definition ArithOps.cpp:328
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:79
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:390
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.