MLIR 23.0.0git
ArithOps.cpp
Go to the documentation of this file.
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37//===----------------------------------------------------------------------===//
38// Pattern helpers
39//===----------------------------------------------------------------------===//
40
41static IntegerAttr
44 function_ref<APInt(const APInt &, const APInt &)> binFn) {
45 const APInt &lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 const APInt &rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(res.getType(), value);
49}
50
51static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54}
55
56static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59}
60
61static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64}
65
66// Merge overflow flags from 2 ops, selecting the most conservative combination.
67static IntegerOverflowFlagsAttr
68mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(val1.getContext(),
71 val1.getValue() & val2.getValue());
72}
73
74/// Invert an integer comparison predicate.
75arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76 switch (pred) {
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
97 }
98 llvm_unreachable("unknown cmpi predicate kind");
99}
100
101/// Equivalent to
102/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103///
104/// Not possible to implement as chain of calls as this would introduce a
105/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106/// on the LLVM dialect and on translation to LLVM.
107static llvm::RoundingMode
108convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
120 }
121 llvm_unreachable("Unhandled rounding mode");
122}
123
124static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125 return arith::CmpIPredicateAttr::get(pred.getContext(),
126 invertPredicate(pred.getValue()));
127}
128
130 Type elemTy = getElementTypeOrSelf(type);
131 if (elemTy.isIntOrFloat())
132 return elemTy.getIntOrFloatBitWidth();
133
134 return -1;
135}
136
138 return getScalarOrElementWidth(value.getType());
139}
140
141static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142 APInt value;
143 if (matchPattern(attr, m_ConstantInt(&value)))
144 return value;
145
146 return failure();
147}
148
149static Attribute getBoolAttribute(Type type, bool value) {
150 auto boolAttr = BoolAttr::get(type.getContext(), value);
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
152 if (!shapedType)
153 return boolAttr;
154 // DenseElementsAttr requires a static shape.
155 if (!shapedType.hasStaticShape())
156 return {};
157 return DenseElementsAttr::get(shapedType, boolAttr);
158}
159
160//===----------------------------------------------------------------------===//
161// TableGen'd canonicalization patterns
162//===----------------------------------------------------------------------===//
163
164namespace {
165#include "ArithCanonicalization.inc"
166} // namespace
167
168//===----------------------------------------------------------------------===//
169// Common helpers
170//===----------------------------------------------------------------------===//
171
172/// Return the type of the same shape (scalar, vector or tensor) containing i1.
174 auto i1Type = IntegerType::get(type.getContext(), 1);
175 if (auto shapedType = dyn_cast<ShapedType>(type))
176 return shapedType.cloneWith(std::nullopt, i1Type);
177 if (llvm::isa<UnrankedTensorType>(type))
178 return UnrankedTensorType::get(i1Type);
179 return i1Type;
180}
181
182//===----------------------------------------------------------------------===//
183// ConstantOp
184//===----------------------------------------------------------------------===//
185
186void arith::ConstantOp::getAsmResultNames(
187 function_ref<void(Value, StringRef)> setNameFn) {
188 auto type = getType();
189 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
190 auto intType = dyn_cast<IntegerType>(type);
191
192 // Sugar i1 constants with 'true' and 'false'.
193 if (intType && intType.getWidth() == 1)
194 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
195
196 // Otherwise, build a complex name with the value and type.
197 SmallString<32> specialNameBuffer;
198 llvm::raw_svector_ostream specialName(specialNameBuffer);
199 specialName << 'c' << intCst.getValue();
200 if (intType)
201 specialName << '_' << type;
202 setNameFn(getResult(), specialName.str());
203 } else {
204 setNameFn(getResult(), "cst");
205 }
206}
207
208/// TODO: disallow arith.constant to return anything other than signless integer
209/// or float like.
210LogicalResult arith::ConstantOp::verify() {
211 auto type = getType();
212 // Integer values must be signless.
213 if (llvm::isa<IntegerType>(type) &&
214 !llvm::cast<IntegerType>(type).isSignless())
215 return emitOpError("integer return type must be signless");
216 // Any float or elements attribute are acceptable.
217 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
218 return emitOpError(
219 "value must be an integer, float, or elements attribute");
220 }
221
222 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
223 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
224 // However, this would most likely require updating the lowerings to LLVM.
225 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
226 return emitOpError(
227 "initializing scalable vectors with elements attribute is not supported"
228 " unless it's a vector splat");
229 return success();
230}
231
232bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
233 // The value's type must be the same as the provided type.
234 auto typedAttr = dyn_cast<TypedAttr>(value);
235 if (!typedAttr || typedAttr.getType() != type)
236 return false;
237 // Integer values must be signless.
238 if (auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(type))) {
239 if (!intType.isSignless())
240 return false;
241 }
242 // Integer, float, and element attributes are buildable.
243 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
244}
245
246ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
247 Type type, Location loc) {
248 if (isBuildableWith(value, type))
249 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
250 return nullptr;
251}
252
253OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
254
256 int64_t value, unsigned width) {
257 auto type = builder.getIntegerType(width);
258 arith::ConstantOp::build(builder, result, type,
259 builder.getIntegerAttr(type, value));
260}
261
263 Location location,
265 unsigned width) {
266 mlir::OperationState state(location, getOperationName());
267 build(builder, state, value, width);
268 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
269 assert(result && "builder didn't return the right type");
270 return result;
271}
272
275 unsigned width) {
276 return create(builder, builder.getLoc(), value, width);
277}
278
280 Type type, int64_t value) {
281 arith::ConstantOp::build(builder, result, type,
282 builder.getIntegerAttr(type, value));
283}
284
286 Location location, Type type,
287 int64_t value) {
288 mlir::OperationState state(location, getOperationName());
289 build(builder, state, type, value);
290 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
291 assert(result && "builder didn't return the right type");
292 return result;
293}
294
296 Type type, int64_t value) {
297 return create(builder, builder.getLoc(), type, value);
298}
299
301 Type type, const APInt &value) {
302 arith::ConstantOp::build(builder, result, type,
303 builder.getIntegerAttr(type, value));
304}
305
307 Location location, Type type,
308 const APInt &value) {
309 mlir::OperationState state(location, getOperationName());
310 build(builder, state, type, value);
311 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
312 assert(result && "builder didn't return the right type");
313 return result;
314}
315
317 Type type,
318 const APInt &value) {
319 return create(builder, builder.getLoc(), type, value);
320}
321
323 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
324 return constOp.getType().isSignlessInteger();
325 return false;
326}
327
329 FloatType type, const APFloat &value) {
330 arith::ConstantOp::build(builder, result, type,
331 builder.getFloatAttr(type, value));
332}
333
335 Location location,
336 FloatType type,
337 const APFloat &value) {
338 mlir::OperationState state(location, getOperationName());
339 build(builder, state, type, value);
340 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
341 assert(result && "builder didn't return the right type");
342 return result;
343}
344
347 const APFloat &value) {
348 return create(builder, builder.getLoc(), type, value);
349}
350
352 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
353 return llvm::isa<FloatType>(constOp.getType());
354 return false;
355}
356
358 int64_t value) {
359 arith::ConstantOp::build(builder, result, builder.getIndexType(),
360 builder.getIndexAttr(value));
361}
362
364 Location location,
365 int64_t value) {
366 mlir::OperationState state(location, getOperationName());
367 build(builder, state, value);
368 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
369 assert(result && "builder didn't return the right type");
370 return result;
371}
372
375 return create(builder, builder.getLoc(), value);
376}
377
379 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
380 return constOp.getType().isIndex();
381 return false;
382}
383
385 Type type) {
386 // TODO: Incorporate this check to `FloatAttr::get*`.
387 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
388 "type doesn't have a zero representation");
389 TypedAttr zeroAttr = builder.getZeroAttr(type);
390 assert(zeroAttr && "unsupported type for zero attribute");
391 return arith::ConstantOp::create(builder, loc, zeroAttr);
392}
393
394//===----------------------------------------------------------------------===//
395// AddIOp
396//===----------------------------------------------------------------------===//
397
398OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
399 // addi(x, 0) -> x
400 if (matchPattern(adaptor.getRhs(), m_Zero()))
401 return getLhs();
402
403 // addi(subi(a, b), b) -> a
404 if (auto sub = getLhs().getDefiningOp<SubIOp>())
405 if (getRhs() == sub.getRhs())
406 return sub.getLhs();
407
408 // addi(b, subi(a, b)) -> a
409 if (auto sub = getRhs().getDefiningOp<SubIOp>())
410 if (getLhs() == sub.getRhs())
411 return sub.getLhs();
412
414 adaptor.getOperands(),
415 [](APInt a, const APInt &b) { return std::move(a) + b; });
416}
417
418void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
419 MLIRContext *context) {
420 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
421 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
422}
423
424//===----------------------------------------------------------------------===//
425// AddUIExtendedOp
426//===----------------------------------------------------------------------===//
427
428std::optional<SmallVector<int64_t, 4>>
429arith::AddUIExtendedOp::getShapeForUnroll() {
430 if (auto vt = dyn_cast<VectorType>(getType(0)))
431 return llvm::to_vector<4>(vt.getShape());
432 return std::nullopt;
433}
434
435// Returns the overflow bit, assuming that `sum` is the result of unsigned
436// addition of `operand` and another number.
437static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
438 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
439}
440
441LogicalResult
442arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
443 SmallVectorImpl<OpFoldResult> &results) {
444 Type overflowTy = getOverflow().getType();
445 // addui_extended(x, 0) -> x, false
446 if (matchPattern(getRhs(), m_Zero())) {
447 Builder builder(getContext());
448 auto falseValue = builder.getZeroAttr(overflowTy);
449
450 results.push_back(getLhs());
451 results.push_back(falseValue);
452 return success();
453 }
454
455 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
456 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
457 // operands. If that succeeds, calculate the overflow bit based on the sum
458 // and the first (constant) operand, `lhs`.
459 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
460 adaptor.getOperands(),
461 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
462 // If any operand is poison, propagate poison to both results.
463 if (matchPattern(sumAttr, ub::m_Poison())) {
464 results.push_back(sumAttr);
465 results.push_back(sumAttr);
466 return success();
467 }
468 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
469 ArrayRef({sumAttr, adaptor.getLhs()}),
470 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
472 if (!overflowAttr)
473 return failure();
474
475 results.push_back(sumAttr);
476 results.push_back(overflowAttr);
477 return success();
478 }
479
480 return failure();
481}
482
483void arith::AddUIExtendedOp::getCanonicalizationPatterns(
484 RewritePatternSet &patterns, MLIRContext *context) {
485 patterns.add<AddUIExtendedToAddI>(context);
486}
487
488//===----------------------------------------------------------------------===//
489// SubIOp
490//===----------------------------------------------------------------------===//
491
492OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
493 // subi(x,x) -> 0
494 if (getOperand(0) == getOperand(1)) {
495 auto shapedType = dyn_cast<ShapedType>(getType());
496 // We can't generate a constant with a dynamic shaped tensor.
497 if (!shapedType || shapedType.hasStaticShape())
498 return Builder(getContext()).getZeroAttr(getType());
499 }
500 // subi(x,0) -> x
501 if (matchPattern(adaptor.getRhs(), m_Zero()))
502 return getLhs();
503
504 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
505 // subi(addi(a, b), b) -> a
506 if (getRhs() == add.getRhs())
507 return add.getLhs();
508 // subi(addi(a, b), a) -> b
509 if (getRhs() == add.getLhs())
510 return add.getRhs();
511 }
512
514 adaptor.getOperands(),
515 [](APInt a, const APInt &b) { return std::move(a) - b; });
516}
517
518void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
519 MLIRContext *context) {
520 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
521 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
522 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
523}
524
525//===----------------------------------------------------------------------===//
526// MulIOp
527//===----------------------------------------------------------------------===//
528
529OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
530 // muli(x, 0) -> 0
531 if (matchPattern(adaptor.getRhs(), m_Zero()))
532 return getRhs();
533 // muli(x, 1) -> x
534 if (matchPattern(adaptor.getRhs(), m_One()))
535 return getLhs();
536 // TODO: Handle the overflow case.
537
538 // default folder
540 adaptor.getOperands(),
541 [](const APInt &a, const APInt &b) { return a * b; });
542}
543
544void arith::MulIOp::getAsmResultNames(
545 function_ref<void(Value, StringRef)> setNameFn) {
546 if (!isa<IndexType>(getType()))
547 return;
548
549 // Match vector.vscale by name to avoid depending on the vector dialect (which
550 // is a circular dependency).
551 auto isVscale = [](Operation *op) {
552 return op && op->getName().getStringRef() == "vector.vscale";
553 };
554
555 IntegerAttr baseValue;
556 auto isVscaleExpr = [&](Value a, Value b) {
557 return matchPattern(a, m_Constant(&baseValue)) &&
558 isVscale(b.getDefiningOp());
559 };
560
561 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
562 return;
563
564 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
565 SmallString<32> specialNameBuffer;
566 llvm::raw_svector_ostream specialName(specialNameBuffer);
567 specialName << 'c' << baseValue.getInt() << "_vscale";
568 setNameFn(getResult(), specialName.str());
569}
570
571void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
572 MLIRContext *context) {
573 patterns.add<MulIMulIConstant>(context);
574}
575
576//===----------------------------------------------------------------------===//
577// MulSIExtendedOp
578//===----------------------------------------------------------------------===//
579
580std::optional<SmallVector<int64_t, 4>>
581arith::MulSIExtendedOp::getShapeForUnroll() {
582 if (auto vt = dyn_cast<VectorType>(getType(0)))
583 return llvm::to_vector<4>(vt.getShape());
584 return std::nullopt;
585}
586
587LogicalResult
588arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
589 SmallVectorImpl<OpFoldResult> &results) {
590 // mulsi_extended(x, 0) -> 0, 0
591 if (matchPattern(adaptor.getRhs(), m_Zero())) {
592 Attribute zero = adaptor.getRhs();
593 results.push_back(zero);
594 results.push_back(zero);
595 return success();
596 }
597
598 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
599 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
600 adaptor.getOperands(),
601 [](const APInt &a, const APInt &b) { return a * b; })) {
602 // Invoke the constant fold helper again to calculate the 'high' result.
603 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
604 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
605 return llvm::APIntOps::mulhs(a, b);
606 });
607 assert(highAttr && "Unexpected constant-folding failure");
608
609 results.push_back(lowAttr);
610 results.push_back(highAttr);
611 return success();
612 }
613
614 return failure();
615}
616
617void arith::MulSIExtendedOp::getCanonicalizationPatterns(
618 RewritePatternSet &patterns, MLIRContext *context) {
619 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
620}
621
622//===----------------------------------------------------------------------===//
623// MulUIExtendedOp
624//===----------------------------------------------------------------------===//
625
626std::optional<SmallVector<int64_t, 4>>
627arith::MulUIExtendedOp::getShapeForUnroll() {
628 if (auto vt = dyn_cast<VectorType>(getType(0)))
629 return llvm::to_vector<4>(vt.getShape());
630 return std::nullopt;
631}
632
633LogicalResult
634arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
635 SmallVectorImpl<OpFoldResult> &results) {
636 // mului_extended(x, 0) -> 0, 0
637 if (matchPattern(adaptor.getRhs(), m_Zero())) {
638 Attribute zero = adaptor.getRhs();
639 results.push_back(zero);
640 results.push_back(zero);
641 return success();
642 }
643
644 // mului_extended(x, 1) -> x, 0
645 if (matchPattern(adaptor.getRhs(), m_One())) {
646 Builder builder(getContext());
647 Attribute zero = builder.getZeroAttr(getLhs().getType());
648 results.push_back(getLhs());
649 results.push_back(zero);
650 return success();
651 }
652
653 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
654 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
655 adaptor.getOperands(),
656 [](const APInt &a, const APInt &b) { return a * b; })) {
657 // Invoke the constant fold helper again to calculate the 'high' result.
658 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
659 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
660 return llvm::APIntOps::mulhu(a, b);
661 });
662 assert(highAttr && "Unexpected constant-folding failure");
663
664 results.push_back(lowAttr);
665 results.push_back(highAttr);
666 return success();
667 }
668
669 return failure();
670}
671
672void arith::MulUIExtendedOp::getCanonicalizationPatterns(
673 RewritePatternSet &patterns, MLIRContext *context) {
674 patterns.add<MulUIExtendedToMulI>(context);
675}
676
677//===----------------------------------------------------------------------===//
678// DivUIOp
679//===----------------------------------------------------------------------===//
680
681/// Fold `(a * b) / b -> a`
683 arith::IntegerOverflowFlags ovfFlags) {
684 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
685 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
686 return {};
687
688 if (mul.getLhs() == rhs)
689 return mul.getRhs();
690
691 if (mul.getRhs() == rhs)
692 return mul.getLhs();
693
694 return {};
695}
696
697OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
698 // divui (x, 1) -> x.
699 if (matchPattern(adaptor.getRhs(), m_One()))
700 return getLhs();
701
702 // (a * b) / b -> a
703 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
704 return val;
705
706 // Don't fold if it would require a division by zero.
707 bool div0 = false;
708 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
709 [&](APInt a, const APInt &b) {
710 if (div0 || !b) {
711 div0 = true;
712 return a;
713 }
714 return a.udiv(b);
715 });
716
717 return div0 ? Attribute() : result;
718}
719
720/// Returns whether an unsigned division by `divisor` is speculatable.
722 // X / 0 => UB
723 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
725
727}
728
729Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
730 return getDivUISpeculatability(getRhs());
731}
732
733//===----------------------------------------------------------------------===//
734// DivSIOp
735//===----------------------------------------------------------------------===//
736
737OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
738 // divsi (x, 1) -> x.
739 if (matchPattern(adaptor.getRhs(), m_One()))
740 return getLhs();
741
742 // (a * b) / b -> a
743 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
744 return val;
745
746 // Don't fold if it would overflow or if it requires a division by zero.
747 bool overflowOrDiv0 = false;
749 adaptor.getOperands(), [&](APInt a, const APInt &b) {
750 if (overflowOrDiv0 || !b) {
751 overflowOrDiv0 = true;
752 return a;
753 }
754 return a.sdiv_ov(b, overflowOrDiv0);
755 });
756
757 return overflowOrDiv0 ? Attribute() : result;
758}
759
760/// Returns whether a signed division by `divisor` is speculatable. This
761/// function conservatively assumes that all signed division by -1 are not
762/// speculatable.
764 // X / 0 => UB
765 // INT_MIN / -1 => UB
766 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
769
771}
772
773Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
774 return getDivSISpeculatability(getRhs());
775}
776
777//===----------------------------------------------------------------------===//
778// Ceil and floor division folding helpers
779//===----------------------------------------------------------------------===//
780
781static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
782 bool &overflow) {
783 // Returns (a-1)/b + 1
784 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
785 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
786 return val.sadd_ov(one, overflow);
787}
788
789//===----------------------------------------------------------------------===//
790// CeilDivUIOp
791//===----------------------------------------------------------------------===//
792
793OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
794 // ceildivui (x, 1) -> x.
795 if (matchPattern(adaptor.getRhs(), m_One()))
796 return getLhs();
797
798 bool overflowOrDiv0 = false;
800 adaptor.getOperands(), [&](APInt a, const APInt &b) {
801 if (overflowOrDiv0 || !b) {
802 overflowOrDiv0 = true;
803 return a;
804 }
805 APInt quotient = a.udiv(b);
806 if (!a.urem(b))
807 return quotient;
808 APInt one(a.getBitWidth(), 1, true);
809 return quotient.uadd_ov(one, overflowOrDiv0);
810 });
811
812 return overflowOrDiv0 ? Attribute() : result;
813}
814
815Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
816 return getDivUISpeculatability(getRhs());
817}
818
819//===----------------------------------------------------------------------===//
820// CeilDivSIOp
821//===----------------------------------------------------------------------===//
822
823OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
824 // ceildivsi (x, 1) -> x.
825 if (matchPattern(adaptor.getRhs(), m_One()))
826 return getLhs();
827
828 // Don't fold if it would overflow or if it requires a division by zero.
829 // TODO: This hook won't fold operations where a = MININT, because
830 // negating MININT overflows. This can be improved.
831 bool overflowOrDiv0 = false;
833 adaptor.getOperands(), [&](APInt a, const APInt &b) {
834 if (overflowOrDiv0 || !b) {
835 overflowOrDiv0 = true;
836 return a;
837 }
838 if (!a)
839 return a;
840 // After this point we know that neither a or b are zero.
841 unsigned bits = a.getBitWidth();
842 APInt zero = APInt::getZero(bits);
843 bool aGtZero = a.sgt(zero);
844 bool bGtZero = b.sgt(zero);
845 if (aGtZero && bGtZero) {
846 // Both positive, return ceil(a, b).
847 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
848 }
849
850 // No folding happens if any of the intermediate arithmetic operations
851 // overflows.
852 bool overflowNegA = false;
853 bool overflowNegB = false;
854 bool overflowDiv = false;
855 bool overflowNegRes = false;
856 if (!aGtZero && !bGtZero) {
857 // Both negative, return ceil(-a, -b).
858 APInt posA = zero.ssub_ov(a, overflowNegA);
859 APInt posB = zero.ssub_ov(b, overflowNegB);
860 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
861 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
862 return res;
863 }
864 if (!aGtZero && bGtZero) {
865 // A is negative, b is positive, return - ( -a / b).
866 APInt posA = zero.ssub_ov(a, overflowNegA);
867 APInt div = posA.sdiv_ov(b, overflowDiv);
868 APInt res = zero.ssub_ov(div, overflowNegRes);
869 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
870 return res;
871 }
872 // A is positive, b is negative, return - (a / -b).
873 APInt posB = zero.ssub_ov(b, overflowNegB);
874 APInt div = a.sdiv_ov(posB, overflowDiv);
875 APInt res = zero.ssub_ov(div, overflowNegRes);
876
877 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
878 return res;
879 });
880
881 return overflowOrDiv0 ? Attribute() : result;
882}
883
884Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
885 return getDivSISpeculatability(getRhs());
886}
887
888//===----------------------------------------------------------------------===//
889// FloorDivSIOp
890//===----------------------------------------------------------------------===//
891
892OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
893 // floordivsi (x, 1) -> x.
894 if (matchPattern(adaptor.getRhs(), m_One()))
895 return getLhs();
896
897 // Don't fold if it would overflow or if it requires a division by zero.
898 bool overflowOrDiv = false;
900 adaptor.getOperands(), [&](APInt a, const APInt &b) {
901 if (b.isZero()) {
902 overflowOrDiv = true;
903 return a;
904 }
905 return a.sfloordiv_ov(b, overflowOrDiv);
906 });
907
908 return overflowOrDiv ? Attribute() : result;
909}
910
911//===----------------------------------------------------------------------===//
912// RemUIOp
913//===----------------------------------------------------------------------===//
914
915OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
916 // remui (x, 1) -> 0.
917 if (matchPattern(adaptor.getRhs(), m_One()))
918 return Builder(getContext()).getZeroAttr(getType());
919
920 // Don't fold if it would require a division by zero.
921 bool div0 = false;
922 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
923 [&](APInt a, const APInt &b) {
924 if (div0 || b.isZero()) {
925 div0 = true;
926 return a;
927 }
928 return a.urem(b);
929 });
930
931 return div0 ? Attribute() : result;
932}
933
934Speculation::Speculatability arith::RemUIOp::getSpeculatability() {
935 return getDivUISpeculatability(getRhs());
936}
937
938//===----------------------------------------------------------------------===//
939// RemSIOp
940//===----------------------------------------------------------------------===//
941
942OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
943 // remsi (x, 1) -> 0.
944 if (matchPattern(adaptor.getRhs(), m_One()))
945 return Builder(getContext()).getZeroAttr(getType());
946
947 // Don't fold if it would require a division by zero.
948 bool div0 = false;
949 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
950 [&](APInt a, const APInt &b) {
951 if (div0 || b.isZero()) {
952 div0 = true;
953 return a;
954 }
955 return a.srem(b);
956 });
957
958 return div0 ? Attribute() : result;
959}
960
961Speculation::Speculatability arith::RemSIOp::getSpeculatability() {
962 // X % 0 => UB
963 // X % -1 is well-defined (always 0), unlike X / -1 which can overflow.
964 if (matchPattern(getRhs(), m_IntRangeWithoutZeroS()))
966
968}
969
970//===----------------------------------------------------------------------===//
971// AndIOp
972//===----------------------------------------------------------------------===//
973
974/// Fold `and(a, and(a, b))` to `and(a, b)`
975static Value foldAndIofAndI(arith::AndIOp op) {
976 for (bool reversePrev : {false, true}) {
977 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
978 .getDefiningOp<arith::AndIOp>();
979 if (!prev)
980 continue;
981
982 Value other = (reversePrev ? op.getLhs() : op.getRhs());
983 if (other != prev.getLhs() && other != prev.getRhs())
984 continue;
985
986 return prev.getResult();
987 }
988 return {};
989}
990
991OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
992 /// and(x, 0) -> 0
993 if (matchPattern(adaptor.getRhs(), m_Zero()))
994 return getRhs();
995 /// and(x, allOnes) -> x
996 APInt intValue;
997 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
998 intValue.isAllOnes())
999 return getLhs();
1000 /// and(x, not(x)) -> 0
1001 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1002 m_ConstantInt(&intValue))) &&
1003 intValue.isAllOnes())
1004 return Builder(getContext()).getZeroAttr(getType());
1005 /// and(not(x), x) -> 0
1006 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1007 m_ConstantInt(&intValue))) &&
1008 intValue.isAllOnes())
1009 return Builder(getContext()).getZeroAttr(getType());
1010
1011 /// and(a, and(a, b)) -> and(a, b)
1012 if (Value result = foldAndIofAndI(*this))
1013 return result;
1014
1016 adaptor.getOperands(),
1017 [](APInt a, const APInt &b) { return std::move(a) & b; });
1018}
1019
1020//===----------------------------------------------------------------------===//
1021// OrIOp
1022//===----------------------------------------------------------------------===//
1023
1024OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1025 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1026 /// or(x, 0) -> x
1027 if (rhsVal.isZero())
1028 return getLhs();
1029 /// or(x, <all ones>) -> <all ones>
1030 if (rhsVal.isAllOnes())
1031 return adaptor.getRhs();
1032 }
1033
1034 APInt intValue;
1035 /// or(x, xor(x, 1)) -> 1
1036 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1037 m_ConstantInt(&intValue))) &&
1038 intValue.isAllOnes())
1039 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1040 /// or(xor(x, 1), x) -> 1
1041 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1042 m_ConstantInt(&intValue))) &&
1043 intValue.isAllOnes())
1044 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1045
1047 adaptor.getOperands(),
1048 [](APInt a, const APInt &b) { return std::move(a) | b; });
1049}
1050
1051//===----------------------------------------------------------------------===//
1052// XOrIOp
1053//===----------------------------------------------------------------------===//
1054
1055OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1056 /// xor(x, 0) -> x
1057 if (matchPattern(adaptor.getRhs(), m_Zero()))
1058 return getLhs();
1059 /// xor(x, x) -> 0
1060 if (getLhs() == getRhs())
1061 return Builder(getContext()).getZeroAttr(getType());
1062 /// xor(xor(x, a), a) -> x
1063 /// xor(xor(a, x), a) -> x
1064 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1065 if (prev.getRhs() == getRhs())
1066 return prev.getLhs();
1067 if (prev.getLhs() == getRhs())
1068 return prev.getRhs();
1069 }
1070 /// xor(a, xor(x, a)) -> x
1071 /// xor(a, xor(a, x)) -> x
1072 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1073 if (prev.getRhs() == getLhs())
1074 return prev.getLhs();
1075 if (prev.getLhs() == getLhs())
1076 return prev.getRhs();
1077 }
1078
1080 adaptor.getOperands(),
1081 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1082}
1083
1084void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1085 MLIRContext *context) {
1086 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1087}
1088
1089//===----------------------------------------------------------------------===//
1090// NegFOp
1091//===----------------------------------------------------------------------===//
1092
1093OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1094 /// negf(negf(x)) -> x
1095 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1096 return op.getOperand();
1097 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1098 [](const APFloat &a) { return -a; });
1099}
1100
1101//===----------------------------------------------------------------------===//
1102// AddFOp
1103//===----------------------------------------------------------------------===//
1104
1105OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1106 // addf(x, -0) -> x
1107 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1108 return getLhs();
1109
1111 adaptor.getOperands(),
1112 [](const APFloat &a, const APFloat &b) { return a + b; });
1113}
1114
1115//===----------------------------------------------------------------------===//
1116// SubFOp
1117//===----------------------------------------------------------------------===//
1118
1119OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1120 // subf(x, +0) -> x
1121 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1122 return getLhs();
1123
1125 adaptor.getOperands(),
1126 [](const APFloat &a, const APFloat &b) { return a - b; });
1127}
1128
1129//===----------------------------------------------------------------------===//
1130// MaximumFOp
1131//===----------------------------------------------------------------------===//
1132
1133OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1134 // maximumf(x,x) -> x
1135 if (getLhs() == getRhs())
1136 return getRhs();
1137
1138 // maximumf(x, -inf) -> x
1139 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1140 return getLhs();
1141
1143 adaptor.getOperands(),
1144 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1145}
1146
1147//===----------------------------------------------------------------------===//
1148// MaxNumFOp
1149//===----------------------------------------------------------------------===//
1150
1151OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1152 // maxnumf(x,x) -> x
1153 if (getLhs() == getRhs())
1154 return getRhs();
1155
1156 // maxnumf(x, NaN) -> x
1157 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1158 return getLhs();
1159
1160 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1161}
1162
1163//===----------------------------------------------------------------------===//
1164// MaxSIOp
1165//===----------------------------------------------------------------------===//
1166
1167OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1168 // maxsi(x,x) -> x
1169 if (getLhs() == getRhs())
1170 return getRhs();
1171
1172 if (APInt intValue;
1173 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1174 // maxsi(x,MAX_INT) -> MAX_INT
1175 if (intValue.isMaxSignedValue())
1176 return getRhs();
1177 // maxsi(x, MIN_INT) -> x
1178 if (intValue.isMinSignedValue())
1179 return getLhs();
1180 }
1181
1182 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1183 [](const APInt &a, const APInt &b) {
1184 return llvm::APIntOps::smax(a, b);
1185 });
1186}
1187
1188//===----------------------------------------------------------------------===//
1189// MaxUIOp
1190//===----------------------------------------------------------------------===//
1191
1192OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1193 // maxui(x,x) -> x
1194 if (getLhs() == getRhs())
1195 return getRhs();
1196
1197 if (APInt intValue;
1198 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1199 // maxui(x,MAX_INT) -> MAX_INT
1200 if (intValue.isMaxValue())
1201 return getRhs();
1202 // maxui(x, MIN_INT) -> x
1203 if (intValue.isMinValue())
1204 return getLhs();
1205 }
1206
1207 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1208 [](const APInt &a, const APInt &b) {
1209 return llvm::APIntOps::umax(a, b);
1210 });
1211}
1212
1213//===----------------------------------------------------------------------===//
1214// MinimumFOp
1215//===----------------------------------------------------------------------===//
1216
1217OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1218 // minimumf(x,x) -> x
1219 if (getLhs() == getRhs())
1220 return getRhs();
1221
1222 // minimumf(x, +inf) -> x
1223 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1224 return getLhs();
1225
1227 adaptor.getOperands(),
1228 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1229}
1230
1231//===----------------------------------------------------------------------===//
1232// MinNumFOp
1233//===----------------------------------------------------------------------===//
1234
1235OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1236 // minnumf(x,x) -> x
1237 if (getLhs() == getRhs())
1238 return getRhs();
1239
1240 // minnumf(x, NaN) -> x
1241 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1242 return getLhs();
1243
1245 adaptor.getOperands(),
1246 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1247}
1248
1249//===----------------------------------------------------------------------===//
1250// MinSIOp
1251//===----------------------------------------------------------------------===//
1252
1253OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1254 // minsi(x,x) -> x
1255 if (getLhs() == getRhs())
1256 return getRhs();
1257
1258 if (APInt intValue;
1259 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1260 // minsi(x,MIN_INT) -> MIN_INT
1261 if (intValue.isMinSignedValue())
1262 return getRhs();
1263 // minsi(x, MAX_INT) -> x
1264 if (intValue.isMaxSignedValue())
1265 return getLhs();
1266 }
1267
1268 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1269 [](const APInt &a, const APInt &b) {
1270 return llvm::APIntOps::smin(a, b);
1271 });
1272}
1273
1274//===----------------------------------------------------------------------===//
1275// MinUIOp
1276//===----------------------------------------------------------------------===//
1277
1278OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1279 // minui(x,x) -> x
1280 if (getLhs() == getRhs())
1281 return getRhs();
1282
1283 if (APInt intValue;
1284 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1285 // minui(x,MIN_INT) -> MIN_INT
1286 if (intValue.isMinValue())
1287 return getRhs();
1288 // minui(x, MAX_INT) -> x
1289 if (intValue.isMaxValue())
1290 return getLhs();
1291 }
1292
1293 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1294 [](const APInt &a, const APInt &b) {
1295 return llvm::APIntOps::umin(a, b);
1296 });
1297}
1298
1299//===----------------------------------------------------------------------===//
1300// MulFOp
1301//===----------------------------------------------------------------------===//
1302
1303OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1304 // mulf(x, 1) -> x
1305 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1306 return getLhs();
1307
1308 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1309 arith::FastMathFlags::nsz)) {
1310 // mulf(x, 0) -> 0
1311 if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1312 return getRhs();
1313 }
1314
1316 adaptor.getOperands(),
1317 [](const APFloat &a, const APFloat &b) { return a * b; });
1318}
1319
1320void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1321 MLIRContext *context) {
1322 patterns.add<MulFOfNegF>(context);
1323}
1324
1325//===----------------------------------------------------------------------===//
1326// DivFOp
1327//===----------------------------------------------------------------------===//
1328
1329OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1330 // divf(x, 1) -> x
1331 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1332 return getLhs();
1333
1335 adaptor.getOperands(),
1336 [](const APFloat &a, const APFloat &b) { return a / b; });
1337}
1338
1339void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1340 MLIRContext *context) {
1341 patterns.add<DivFOfNegF>(context);
1342}
1343
1344//===----------------------------------------------------------------------===//
1345// RemFOp
1346//===----------------------------------------------------------------------===//
1347
1348OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1349 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1350 [](const APFloat &a, const APFloat &b) {
1351 APFloat result(a);
1352 // APFloat::mod() offers the remainder
1353 // behavior we want, i.e. the result has
1354 // the sign of LHS operand.
1355 (void)result.mod(b);
1356 return result;
1357 });
1358}
1359
1360//===----------------------------------------------------------------------===//
1361// Utility functions for verifying cast ops
1362//===----------------------------------------------------------------------===//
1363
/// Tag type that carries a whole pack of types as a single (pointer) type.
/// Passing such tags as function arguments lets several independent type
/// packs be deduced in one call.
template <typename... Types>
using type_list =
    std::tuple<Types...> *;
1366
/// Returns the underlying (element) type of `type` if it is allowed, and a
/// null Type otherwise:
///  - a shaped `type` is rejected unless it is one of `ShapedTypes`;
///  - the element type (or `type` itself when it is not shaped) must be one of
///    `ElementTypes`.
template <typename... ShapedTypes, typename... ElementTypes>
  // Reject shaped types that are not in the allowed container list.
  if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
    return {};

  // For a valid shaped type this is its element type; a scalar is kept as is.
  auto underlyingType = getElementTypeOrSelf(type);
  if (!llvm::isa<ElementTypes...>(underlyingType))
    return {};

  return underlyingType;
}
1382
/// Returns the underlying type of `type` when it is one of `ElementTypes`,
/// either directly or as the element type of a vector or tensor. Callers treat
/// a null result as "not allowed".
template <typename... ElementTypes>
1389
/// Like getTypeIfLike, but memrefs are accepted as containers in addition to
/// vectors and tensors. Callers treat a null result as "not allowed".
template <typename... ElementTypes>
1397
1398/// Return false if both types are ranked tensor with mismatching encoding.
1399static bool hasSameEncoding(Type typeA, Type typeB) {
1400 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1401 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1402 if (!rankedTensorA || !rankedTensorB)
1403 return true;
1404 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1405}
1406
1408 if (inputs.size() != 1 || outputs.size() != 1)
1409 return false;
1410 if (!hasSameEncoding(inputs.front(), outputs.front()))
1411 return false;
1412 return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1413}
1414
1415//===----------------------------------------------------------------------===//
1416// Verifiers for integer and floating point extension/truncation ops
1417//===----------------------------------------------------------------------===//
1418
1419// Extend ops can only extend to a wider type.
1420template <typename ValType, typename Op>
1421static LogicalResult verifyExtOp(Op op) {
1422 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1423 Type dstType = getElementTypeOrSelf(op.getType());
1424
1425 if (llvm::cast<ValType>(srcType).getWidth() >=
1426 llvm::cast<ValType>(dstType).getWidth())
1427 return op.emitError("result type ")
1428 << dstType << " must be wider than operand type " << srcType;
1429
1430 return success();
1431}
1432
1433// Truncate ops can only truncate to a shorter type.
1434template <typename ValType, typename Op>
1435static LogicalResult verifyTruncateOp(Op op) {
1436 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1437 Type dstType = getElementTypeOrSelf(op.getType());
1438
1439 if (llvm::cast<ValType>(srcType).getWidth() <=
1440 llvm::cast<ValType>(dstType).getWidth())
1441 return op.emitError("result type ")
1442 << dstType << " must be shorter than operand type " << srcType;
1443
1444 return success();
1445}
1446
1447/// Validate a cast that changes the width of a type.
1448template <template <typename> class WidthComparator, typename... ElementTypes>
1449static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1450 if (!areValidCastInputsAndOutputs(inputs, outputs))
1451 return false;
1452
1453 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1454 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1455 if (!srcType || !dstType)
1456 return false;
1457
1458 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1459 srcType.getIntOrFloatBitWidth());
1460}
1461
1462/// Attempts to convert `sourceValue` to an APFloat value with
1463/// `targetSemantics` and `roundingMode`, without any information loss.
1464static FailureOr<APFloat> convertFloatValue(
1465 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1466 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1467 // Reject special values that are not representable in the target type before
1468 // calling APFloat::convert, which would llvm_unreachable on them.
1469 using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
1470 if (sourceValue.isInfinity() &&
1471 (targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::NanOnly ||
1472 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly))
1473 return failure();
1474 if (sourceValue.isNaN() &&
1475 targetSemantics.nonFiniteBehavior == fltNonfiniteBehavior::FiniteOnly)
1476 return failure();
1477
1478 bool losesInfo = false;
1479 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1480 if (losesInfo || status != APFloat::opOK)
1481 return failure();
1482
1483 return sourceValue;
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// ExtUIOp
1488//===----------------------------------------------------------------------===//
1489
1490OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1491 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1492 getInMutable().assign(lhs.getIn());
1493 return getResult();
1494 }
1495
1496 Type resType = getElementTypeOrSelf(getType());
1497 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1499 adaptor.getOperands(), getType(),
1500 [bitWidth](const APInt &a, bool &castStatus) {
1501 return a.zext(bitWidth);
1502 });
1503}
1504
1505bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1507}
1508
1509LogicalResult arith::ExtUIOp::verify() {
1510 return verifyExtOp<IntegerType>(*this);
1511}
1512
1513//===----------------------------------------------------------------------===//
1514// ExtSIOp
1515//===----------------------------------------------------------------------===//
1516
1517OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1518 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1519 getInMutable().assign(lhs.getIn());
1520 return getResult();
1521 }
1522
1523 Type resType = getElementTypeOrSelf(getType());
1524 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1526 adaptor.getOperands(), getType(),
1527 [bitWidth](const APInt &a, bool &castStatus) {
1528 return a.sext(bitWidth);
1529 });
1530}
1531
1532bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1534}
1535
1536void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1537 MLIRContext *context) {
1538 patterns.add<ExtSIOfExtUI>(context);
1539}
1540
1541LogicalResult arith::ExtSIOp::verify() {
1542 return verifyExtOp<IntegerType>(*this);
1543}
1544
1545//===----------------------------------------------------------------------===//
1546// ExtFOp
1547//===----------------------------------------------------------------------===//
1548
1549/// Fold extension of float constants when there is no information loss due the
1550/// difference in fp semantics.
1551OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1552 if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1553 if (truncFOp.getOperand().getType() == getType()) {
1554 arith::FastMathFlags truncFMF =
1555 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1556 bool isTruncContract =
1557 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1558 arith::FastMathFlags extFMF =
1559 getFastmath().value_or(arith::FastMathFlags::none);
1560 bool isExtContract =
1561 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1562 if (isTruncContract && isExtContract) {
1563 return truncFOp.getOperand();
1564 }
1565 }
1566 }
1567
1568 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1569 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1571 adaptor.getOperands(), getType(),
1572 [&targetSemantics](const APFloat &a, bool &castStatus) {
1573 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1574 if (failed(result)) {
1575 castStatus = false;
1576 return a;
1577 }
1578 return *result;
1579 });
1580}
1581
1582bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1583 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1584}
1585
1586LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1587
1588//===----------------------------------------------------------------------===//
1589// ScalingExtFOp
1590//===----------------------------------------------------------------------===//
1591
1592bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1593 TypeRange outputs) {
1594 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1595}
1596
1597LogicalResult arith::ScalingExtFOp::verify() {
1598 return verifyExtOp<FloatType>(*this);
1599}
1600
1601//===----------------------------------------------------------------------===//
1602// TruncIOp
1603//===----------------------------------------------------------------------===//
1604
1605OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1606 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1607 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1608 Value src = getOperand().getDefiningOp()->getOperand(0);
1609 Type srcType = getElementTypeOrSelf(src.getType());
1610 Type dstType = getElementTypeOrSelf(getType());
1611 // trunci(zexti(a)) -> trunci(a)
1612 // trunci(sexti(a)) -> trunci(a)
1613 if (llvm::cast<IntegerType>(srcType).getWidth() >
1614 llvm::cast<IntegerType>(dstType).getWidth()) {
1615 setOperand(src);
1616 return getResult();
1617 }
1618
1619 // trunci(zexti(a)) -> a
1620 // trunci(sexti(a)) -> a
1621 if (srcType == dstType)
1622 return src;
1623 }
1624
1625 // trunci(trunci(a)) -> trunci(a))
1626 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1627 setOperand(getOperand().getDefiningOp()->getOperand(0));
1628 return getResult();
1629 }
1630
1631 Type resType = getElementTypeOrSelf(getType());
1632 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1634 adaptor.getOperands(), getType(),
1635 [bitWidth](const APInt &a, bool &castStatus) {
1636 return a.trunc(bitWidth);
1637 });
1638}
1639
1640bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1641 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1642}
1643
1644void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1645 MLIRContext *context) {
1646 patterns
1647 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1648 context);
1649}
1650
1651LogicalResult arith::TruncIOp::verify() {
1652 return verifyTruncateOp<IntegerType>(*this);
1653}
1654
1655//===----------------------------------------------------------------------===//
1656// TruncFOp
1657//===----------------------------------------------------------------------===//
1658
1659/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1660/// can be represented without precision loss.
1661OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1662 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1663 if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1664 Value src = extOp.getIn();
1665 auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1666 auto intermediateType =
1667 cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1668 // Check if the srcType is representable in the intermediateType.
1669 if (llvm::APFloatBase::isRepresentableBy(
1670 srcType.getFloatSemantics(),
1671 intermediateType.getFloatSemantics())) {
1672 // truncf(extf(a)) -> truncf(a)
1673 if (srcType.getWidth() > resElemType.getWidth()) {
1674 setOperand(src);
1675 return getResult();
1676 }
1677
1678 // truncf(extf(a)) -> a
1679 if (srcType == resElemType)
1680 return src;
1681 }
1682 }
1683
1684 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1686 adaptor.getOperands(), getType(),
1687 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1688 RoundingMode roundingMode =
1689 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1690 llvm::RoundingMode llvmRoundingMode =
1692 FailureOr<APFloat> result =
1693 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1694 if (failed(result)) {
1695 castStatus = false;
1696 return a;
1697 }
1698 return *result;
1699 });
1700}
1701
1702void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1703 MLIRContext *context) {
1704 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1705}
1706
1707bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1708 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1709}
1710
1711LogicalResult arith::TruncFOp::verify() {
1712 return verifyTruncateOp<FloatType>(*this);
1713}
1714
1715//===----------------------------------------------------------------------===//
1716// ConvertFOp
1717//===----------------------------------------------------------------------===//
1718
1719OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
1720 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1721 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1723 adaptor.getOperands(), getType(),
1724 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1725 RoundingMode roundingMode =
1726 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1727 llvm::RoundingMode llvmRoundingMode =
1729 FailureOr<APFloat> result =
1730 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1731 if (failed(result)) {
1732 castStatus = false;
1733 return a;
1734 }
1735 return *result;
1736 });
1737}
1738
1739bool arith::ConvertFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1740 if (!areValidCastInputsAndOutputs(inputs, outputs))
1741 return false;
1742 auto srcType = getTypeIfLike<FloatType>(inputs.front());
1743 auto dstType = getTypeIfLike<FloatType>(outputs.front());
1744 if (!srcType || !dstType)
1745 return false;
1746 return srcType != dstType &&
1747 srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1748}
1749
1750LogicalResult arith::ConvertFOp::verify() {
1751 auto srcType = cast<FloatType>(getElementTypeOrSelf(getIn().getType()));
1752 auto dstType = cast<FloatType>(getElementTypeOrSelf(getType()));
1753 if (srcType == dstType)
1754 return emitError("result element type ")
1755 << dstType << " must be different from operand element type "
1756 << srcType;
1757 if (srcType.getWidth() != dstType.getWidth())
1758 return emitError("result element type ")
1759 << dstType << " must have the same bitwidth as operand element type "
1760 << srcType;
1761 return success();
1762}
1763
1764//===----------------------------------------------------------------------===//
1765// ScalingTruncFOp
1766//===----------------------------------------------------------------------===//
1767
1768bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1769 TypeRange outputs) {
1770 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1771}
1772
1773LogicalResult arith::ScalingTruncFOp::verify() {
1774 return verifyTruncateOp<FloatType>(*this);
1775}
1776
1777//===----------------------------------------------------------------------===//
1778// AndIOp
1779//===----------------------------------------------------------------------===//
1780
1781void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1782 MLIRContext *context) {
1783 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1784}
1785
1786//===----------------------------------------------------------------------===//
1787// OrIOp
1788//===----------------------------------------------------------------------===//
1789
1790void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1791 MLIRContext *context) {
1792 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1793}
1794
1795//===----------------------------------------------------------------------===//
1796// Verifiers for casts between integers and floats.
1797//===----------------------------------------------------------------------===//
1798
1799template <typename From, typename To>
1800static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1801 if (!areValidCastInputsAndOutputs(inputs, outputs))
1802 return false;
1803
1804 auto srcType = getTypeIfLike<From>(inputs.front());
1805 auto dstType = getTypeIfLike<To>(outputs.back());
1806
1807 return srcType && dstType;
1808}
1809
1810//===----------------------------------------------------------------------===//
1811// UIToFPOp
1812//===----------------------------------------------------------------------===//
1813
1814bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1815 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1816}
1817
1818OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1819 Type resEleType = getElementTypeOrSelf(getType());
1821 adaptor.getOperands(), getType(),
1822 [&resEleType](const APInt &a, bool &castStatus) {
1823 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1824 APFloat apf(floatTy.getFloatSemantics(),
1825 APInt::getZero(floatTy.getWidth()));
1826 apf.convertFromAPInt(a, /*IsSigned=*/false,
1827 APFloat::rmNearestTiesToEven);
1828 return apf;
1829 });
1830}
1831
1832void arith::UIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1833 MLIRContext *context) {
1834 patterns.add<UIToFPOfExtUI>(context);
1835}
1836
1837//===----------------------------------------------------------------------===//
1838// SIToFPOp
1839//===----------------------------------------------------------------------===//
1840
1841bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1842 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1843}
1844
1845OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1846 Type resEleType = getElementTypeOrSelf(getType());
1848 adaptor.getOperands(), getType(),
1849 [&resEleType](const APInt &a, bool &castStatus) {
1850 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1851 APFloat apf(floatTy.getFloatSemantics(),
1852 APInt::getZero(floatTy.getWidth()));
1853 apf.convertFromAPInt(a, /*IsSigned=*/true,
1854 APFloat::rmNearestTiesToEven);
1855 return apf;
1856 });
1857}
1858
1859void arith::SIToFPOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1860 MLIRContext *context) {
1861 patterns.add<SIToFPOfExtSI, SIToFPOfExtUI>(context);
1862}
1863
1864//===----------------------------------------------------------------------===//
1865// FPToUIOp
1866//===----------------------------------------------------------------------===//
1867
1868bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1869 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1870}
1871
1872OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1873 Type resType = getElementTypeOrSelf(getType());
1874 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1876 adaptor.getOperands(), getType(),
1877 [&bitWidth](const APFloat &a, bool &castStatus) {
1878 bool ignored;
1879 APSInt api(bitWidth, /*isUnsigned=*/true);
1880 castStatus = APFloat::opInvalidOp !=
1881 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1882 return api;
1883 });
1884}
1885
1886//===----------------------------------------------------------------------===//
1887// FPToSIOp
1888//===----------------------------------------------------------------------===//
1889
1890bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1891 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1892}
1893
1894OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1895 Type resType = getElementTypeOrSelf(getType());
1896 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1898 adaptor.getOperands(), getType(),
1899 [&bitWidth](const APFloat &a, bool &castStatus) {
1900 bool ignored;
1901 APSInt api(bitWidth, /*isUnsigned=*/false);
1902 castStatus = APFloat::opInvalidOp !=
1903 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1904 return api;
1905 });
1906}
1907
1908//===----------------------------------------------------------------------===//
1909// IndexCastOp
1910//===----------------------------------------------------------------------===//
1911
1912static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1913 if (!areValidCastInputsAndOutputs(inputs, outputs))
1914 return false;
1915
1916 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1917 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1918 if (!srcType || !dstType)
1919 return false;
1920
1921 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1922 (srcType.isSignlessInteger() && dstType.isIndex());
1923}
1924
1925bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1926 TypeRange outputs) {
1927 return areIndexCastCompatible(inputs, outputs);
1928}
1929
1930OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1931 // index_cast(constant) -> constant
1932 unsigned resultBitwidth = 64; // Default for index integer attributes.
1933 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1934 resultBitwidth = intTy.getWidth();
1935
1937 adaptor.getOperands(), getType(),
1938 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1939 return a.sextOrTrunc(resultBitwidth);
1940 });
1941}
1942
1943void arith::IndexCastOp::getCanonicalizationPatterns(
1944 RewritePatternSet &patterns, MLIRContext *context) {
1945 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1946}
1947
1948//===----------------------------------------------------------------------===//
1949// IndexCastUIOp
1950//===----------------------------------------------------------------------===//
1951
1952bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1953 TypeRange outputs) {
1954 return areIndexCastCompatible(inputs, outputs);
1955}
1956
1957OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1958 // index_castui(constant) -> constant
1959 unsigned resultBitwidth = 64; // Default for index integer attributes.
1960 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1961 resultBitwidth = intTy.getWidth();
1962
1964 adaptor.getOperands(), getType(),
1965 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1966 return a.zextOrTrunc(resultBitwidth);
1967 });
1968}
1969
1970void arith::IndexCastUIOp::getCanonicalizationPatterns(
1971 RewritePatternSet &patterns, MLIRContext *context) {
1972 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1973}
1974
1975//===----------------------------------------------------------------------===//
1976// BitcastOp
1977//===----------------------------------------------------------------------===//
1978
1979bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1980 if (!areValidCastInputsAndOutputs(inputs, outputs))
1981 return false;
1982
1983 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1984 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1985 if (!srcType || !dstType)
1986 return false;
1987
1988 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1989}
1990
1991OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1992 auto resType = getType();
1993 auto operand = adaptor.getIn();
1994 if (!operand)
1995 return {};
1996
1997 /// Bitcast dense elements.
1998 if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1999 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
2000 /// Other shaped types unhandled.
2001 if (llvm::isa<ShapedType>(resType))
2002 return {};
2003
2004 /// Bitcast poison.
2005 if (matchPattern(operand, ub::m_Poison()))
2006 return ub::PoisonAttr::get(getContext());
2007
2008 /// Bitcast integer or float to integer or float.
2009 APInt bits = llvm::isa<FloatAttr>(operand)
2010 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
2011 : llvm::cast<IntegerAttr>(operand).getValue();
2012 assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
2013 "trying to fold on broken IR: operands have incompatible types");
2014
2015 if (auto resFloatType = dyn_cast<FloatType>(resType))
2016 return FloatAttr::get(resType,
2017 APFloat(resFloatType.getFloatSemantics(), bits));
2018 return IntegerAttr::get(resType, bits);
2019}
2020
2021void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2022 MLIRContext *context) {
2023 patterns.add<BitcastOfBitcast>(context);
2024}
2025
2026//===----------------------------------------------------------------------===//
2027// CmpIOp
2028//===----------------------------------------------------------------------===//
2029
2030/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
2031/// comparison predicates.
2032bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
2033 const APInt &lhs, const APInt &rhs) {
2034 switch (predicate) {
2035 case arith::CmpIPredicate::eq:
2036 return lhs.eq(rhs);
2037 case arith::CmpIPredicate::ne:
2038 return lhs.ne(rhs);
2039 case arith::CmpIPredicate::slt:
2040 return lhs.slt(rhs);
2041 case arith::CmpIPredicate::sle:
2042 return lhs.sle(rhs);
2043 case arith::CmpIPredicate::sgt:
2044 return lhs.sgt(rhs);
2045 case arith::CmpIPredicate::sge:
2046 return lhs.sge(rhs);
2047 case arith::CmpIPredicate::ult:
2048 return lhs.ult(rhs);
2049 case arith::CmpIPredicate::ule:
2050 return lhs.ule(rhs);
2051 case arith::CmpIPredicate::ugt:
2052 return lhs.ugt(rhs);
2053 case arith::CmpIPredicate::uge:
2054 return lhs.uge(rhs);
2055 }
2056 llvm_unreachable("unknown cmpi predicate kind");
2057}
2058
2059/// Returns true if the predicate is true for two equal operands.
2060static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
2061 switch (predicate) {
2062 case arith::CmpIPredicate::eq:
2063 case arith::CmpIPredicate::sle:
2064 case arith::CmpIPredicate::sge:
2065 case arith::CmpIPredicate::ule:
2066 case arith::CmpIPredicate::uge:
2067 return true;
2068 case arith::CmpIPredicate::ne:
2069 case arith::CmpIPredicate::slt:
2070 case arith::CmpIPredicate::sgt:
2071 case arith::CmpIPredicate::ult:
2072 case arith::CmpIPredicate::ugt:
2073 return false;
2074 }
2075 llvm_unreachable("unknown cmpi predicate kind");
2076}
2077
2078static std::optional<int64_t> getIntegerWidth(Type t) {
2079 if (auto intType = dyn_cast<IntegerType>(t)) {
2080 return intType.getWidth();
2081 }
2082 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
2083 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
2084 }
2085 return std::nullopt;
2086}
2087
2088OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
2089 // cmpi(pred, x, x)
2090 if (getLhs() == getRhs()) {
2091 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2092 return getBoolAttribute(getType(), val);
2093 }
2094
2095 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2096 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2097 // extsi(%x : i1 -> iN) != 0 -> %x
2098 std::optional<int64_t> integerWidth =
2099 getIntegerWidth(extOp.getOperand().getType());
2100 if (integerWidth && integerWidth.value() == 1 &&
2101 getPredicate() == arith::CmpIPredicate::ne)
2102 return extOp.getOperand();
2103 }
2104 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2105 // extui(%x : i1 -> iN) != 0 -> %x
2106 std::optional<int64_t> integerWidth =
2107 getIntegerWidth(extOp.getOperand().getType());
2108 if (integerWidth && integerWidth.value() == 1 &&
2109 getPredicate() == arith::CmpIPredicate::ne)
2110 return extOp.getOperand();
2111 }
2112
2113 // arith.cmpi ne, %val, %zero : i1 -> %val
2114 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2115 getPredicate() == arith::CmpIPredicate::ne)
2116 return getLhs();
2117 }
2118
2119 if (matchPattern(adaptor.getRhs(), m_One())) {
2120 // arith.cmpi eq, %val, %one : i1 -> %val
2121 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2122 getPredicate() == arith::CmpIPredicate::eq)
2123 return getLhs();
2124 }
2125
2126 // Move constant to the right side.
2127 if (adaptor.getLhs() && !adaptor.getRhs()) {
2128 // Do not use invertPredicate, as it will change eq to ne and vice versa.
2129 using Pred = CmpIPredicate;
2130 const std::pair<Pred, Pred> invPreds[] = {
2131 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2132 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2133 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2134 {Pred::ne, Pred::ne},
2135 };
2136 Pred origPred = getPredicate();
2137 for (auto pred : invPreds) {
2138 if (origPred == pred.first) {
2139 setPredicate(pred.second);
2140 Value lhs = getLhs();
2141 Value rhs = getRhs();
2142 getLhsMutable().assign(rhs);
2143 getRhsMutable().assign(lhs);
2144 return getResult();
2145 }
2146 }
2147 llvm_unreachable("unknown cmpi predicate kind");
2148 }
2149
2150 // We are moving constants to the right side; So if lhs is constant rhs is
2151 // guaranteed to be a constant.
2152 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2154 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2155 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2156 return APInt(1,
2157 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2158 });
2159 }
2160
2161 return {};
2162}
2163
2164void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2165 MLIRContext *context) {
2166 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2167}
2168
2169//===----------------------------------------------------------------------===//
2170// CmpFOp
2171//===----------------------------------------------------------------------===//
2172
2173/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2174/// comparison predicates.
2175bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2176 const APFloat &lhs, const APFloat &rhs) {
2177 auto cmpResult = lhs.compare(rhs);
2178 switch (predicate) {
2179 case arith::CmpFPredicate::AlwaysFalse:
2180 return false;
2181 case arith::CmpFPredicate::OEQ:
2182 return cmpResult == APFloat::cmpEqual;
2183 case arith::CmpFPredicate::OGT:
2184 return cmpResult == APFloat::cmpGreaterThan;
2185 case arith::CmpFPredicate::OGE:
2186 return cmpResult == APFloat::cmpGreaterThan ||
2187 cmpResult == APFloat::cmpEqual;
2188 case arith::CmpFPredicate::OLT:
2189 return cmpResult == APFloat::cmpLessThan;
2190 case arith::CmpFPredicate::OLE:
2191 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2192 case arith::CmpFPredicate::ONE:
2193 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2194 case arith::CmpFPredicate::ORD:
2195 return cmpResult != APFloat::cmpUnordered;
2196 case arith::CmpFPredicate::UEQ:
2197 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2198 case arith::CmpFPredicate::UGT:
2199 return cmpResult == APFloat::cmpUnordered ||
2200 cmpResult == APFloat::cmpGreaterThan;
2201 case arith::CmpFPredicate::UGE:
2202 return cmpResult == APFloat::cmpUnordered ||
2203 cmpResult == APFloat::cmpGreaterThan ||
2204 cmpResult == APFloat::cmpEqual;
2205 case arith::CmpFPredicate::ULT:
2206 return cmpResult == APFloat::cmpUnordered ||
2207 cmpResult == APFloat::cmpLessThan;
2208 case arith::CmpFPredicate::ULE:
2209 return cmpResult == APFloat::cmpUnordered ||
2210 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2211 case arith::CmpFPredicate::UNE:
2212 return cmpResult != APFloat::cmpEqual;
2213 case arith::CmpFPredicate::UNO:
2214 return cmpResult == APFloat::cmpUnordered;
2215 case arith::CmpFPredicate::AlwaysTrue:
2216 return true;
2217 }
2218 llvm_unreachable("unknown cmpf predicate kind");
2219}
2220
2221OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2222 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2223 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2224
2225 // If one operand is NaN, making them both NaN does not change the result.
2226 if (lhs && lhs.getValue().isNaN())
2227 rhs = lhs;
2228 if (rhs && rhs.getValue().isNaN())
2229 lhs = rhs;
2230
2231 if (!lhs || !rhs)
2232 return {};
2233
2234 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2235 return BoolAttr::get(getContext(), val);
2236}
2237
2238class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2239public:
2240 using Base::Base;
2241
2242 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2243 bool isUnsigned) {
2244 using namespace arith;
2245 switch (pred) {
2246 case CmpFPredicate::UEQ:
2247 case CmpFPredicate::OEQ:
2248 return CmpIPredicate::eq;
2249 case CmpFPredicate::UGT:
2250 case CmpFPredicate::OGT:
2251 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2252 case CmpFPredicate::UGE:
2253 case CmpFPredicate::OGE:
2254 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2255 case CmpFPredicate::ULT:
2256 case CmpFPredicate::OLT:
2257 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2258 case CmpFPredicate::ULE:
2259 case CmpFPredicate::OLE:
2260 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2261 case CmpFPredicate::UNE:
2262 case CmpFPredicate::ONE:
2263 return CmpIPredicate::ne;
2264 default:
2265 llvm_unreachable("Unexpected predicate!");
2266 }
2267 }
2268
2269 LogicalResult matchAndRewrite(CmpFOp op,
2270 PatternRewriter &rewriter) const override {
2271 FloatAttr flt;
2272 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2273 return failure();
2274
2275 const APFloat &rhs = flt.getValue();
2276
2277 // Don't attempt to fold a nan.
2278 if (rhs.isNaN())
2279 return failure();
2280
2281 // Get the width of the mantissa. We don't want to hack on conversions that
2282 // might lose information from the integer, e.g. "i64 -> float"
2283 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2284 int mantissaWidth = floatTy.getFPMantissaWidth();
2285 if (mantissaWidth <= 0)
2286 return failure();
2287
2288 bool isUnsigned;
2289 Value intVal;
2290
2291 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2292 isUnsigned = false;
2293 intVal = si.getIn();
2294 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2295 isUnsigned = true;
2296 intVal = ui.getIn();
2297 } else {
2298 return failure();
2299 }
2300
2301 // Check to see that the input is converted from an integer type that is
2302 // small enough that preserves all bits.
2303 auto intTy = llvm::cast<IntegerType>(intVal.getType());
2304 auto intWidth = intTy.getWidth();
2305
2306 // Number of bits representing values, as opposed to the sign
2307 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2308
2309 // Following test does NOT adjust intWidth downwards for signed inputs,
2310 // because the most negative value still requires all the mantissa bits
2311 // to distinguish it from one less than that value.
2312 if ((int)intWidth > mantissaWidth) {
2313 // Conversion would lose accuracy. Check if loss can impact comparison.
2314 int exponent = ilogb(rhs);
2315 if (exponent == APFloat::IEK_Inf) {
2316 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2317 if (maxExponent < (int)valueBits) {
2318 // Conversion could create infinity.
2319 return failure();
2320 }
2321 } else {
2322 // Note that if rhs is zero or NaN, then Exp is negative
2323 // and first condition is trivially false.
2324 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2325 // Conversion could affect comparison.
2326 return failure();
2327 }
2328 }
2329 }
2330
2331 // Convert to equivalent cmpi predicate
2332 CmpIPredicate pred;
2333 switch (op.getPredicate()) {
2334 case CmpFPredicate::ORD:
2335 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2336 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2337 /*width=*/1);
2338 return success();
2339 case CmpFPredicate::UNO:
2340 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2341 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2342 /*width=*/1);
2343 return success();
2344 default:
2345 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2346 break;
2347 }
2348
2349 if (!isUnsigned) {
2350 // If the rhs value is > SignedMax, fold the comparison. This handles
2351 // +INF and large values.
2352 APFloat signedMax(rhs.getSemantics());
2353 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2354 APFloat::rmNearestTiesToEven);
2355 if (signedMax < rhs) { // smax < 13123.0
2356 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2357 pred == CmpIPredicate::sle)
2358 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2359 /*width=*/1);
2360 else
2361 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2362 /*width=*/1);
2363 return success();
2364 }
2365 } else {
2366 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2367 // +INF and large values.
2368 APFloat unsignedMax(rhs.getSemantics());
2369 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2370 APFloat::rmNearestTiesToEven);
2371 if (unsignedMax < rhs) { // umax < 13123.0
2372 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2373 pred == CmpIPredicate::ule)
2374 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2375 /*width=*/1);
2376 else
2377 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2378 /*width=*/1);
2379 return success();
2380 }
2381 }
2382
2383 if (!isUnsigned) {
2384 // See if the rhs value is < SignedMin.
2385 APFloat signedMin(rhs.getSemantics());
2386 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2387 APFloat::rmNearestTiesToEven);
2388 if (signedMin > rhs) { // smin > 12312.0
2389 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2390 pred == CmpIPredicate::sge)
2391 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2392 /*width=*/1);
2393 else
2394 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2395 /*width=*/1);
2396 return success();
2397 }
2398 } else {
2399 // See if the rhs value is < UnsignedMin.
2400 APFloat unsignedMin(rhs.getSemantics());
2401 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2402 APFloat::rmNearestTiesToEven);
2403 if (unsignedMin > rhs) { // umin > 12312.0
2404 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2405 pred == CmpIPredicate::uge)
2406 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2407 /*width=*/1);
2408 else
2409 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2410 /*width=*/1);
2411 return success();
2412 }
2413 }
2414
2415 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2416 // [0, UMAX], but it may still be fractional. See if it is fractional by
2417 // casting the FP value to the integer value and back, checking for
2418 // equality. Don't do this for zero, because -0.0 is not fractional.
2419 bool ignored;
2420 APSInt rhsInt(intWidth, isUnsigned);
2421 if (APFloat::opInvalidOp ==
2422 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2423 // Undefined behavior invoked - the destination type can't represent
2424 // the input constant.
2425 return failure();
2426 }
2427
2428 if (!rhs.isZero()) {
2429 APFloat apf(floatTy.getFloatSemantics(),
2430 APInt::getZero(floatTy.getWidth()));
2431 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2432
2433 bool equal = apf == rhs;
2434 if (!equal) {
2435 // If we had a comparison against a fractional value, we have to adjust
2436 // the compare predicate and sometimes the value. rhsInt is rounded
2437 // towards zero at this point.
2438 switch (pred) {
2439 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2440 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2441 /*width=*/1);
2442 return success();
2443 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2444 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2445 /*width=*/1);
2446 return success();
2447 case CmpIPredicate::ule:
2448 // (float)int <= 4.4 --> int <= 4
2449 // (float)int <= -4.4 --> false
2450 if (rhs.isNegative()) {
2451 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2452 /*width=*/1);
2453 return success();
2454 }
2455 break;
2456 case CmpIPredicate::sle:
2457 // (float)int <= 4.4 --> int <= 4
2458 // (float)int <= -4.4 --> int < -4
2459 if (rhs.isNegative())
2460 pred = CmpIPredicate::slt;
2461 break;
2462 case CmpIPredicate::ult:
2463 // (float)int < -4.4 --> false
2464 // (float)int < 4.4 --> int <= 4
2465 if (rhs.isNegative()) {
2466 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2467 /*width=*/1);
2468 return success();
2469 }
2470 pred = CmpIPredicate::ule;
2471 break;
2472 case CmpIPredicate::slt:
2473 // (float)int < -4.4 --> int < -4
2474 // (float)int < 4.4 --> int <= 4
2475 if (!rhs.isNegative())
2476 pred = CmpIPredicate::sle;
2477 break;
2478 case CmpIPredicate::ugt:
2479 // (float)int > 4.4 --> int > 4
2480 // (float)int > -4.4 --> true
2481 if (rhs.isNegative()) {
2482 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2483 /*width=*/1);
2484 return success();
2485 }
2486 break;
2487 case CmpIPredicate::sgt:
2488 // (float)int > 4.4 --> int > 4
2489 // (float)int > -4.4 --> int >= -4
2490 if (rhs.isNegative())
2491 pred = CmpIPredicate::sge;
2492 break;
2493 case CmpIPredicate::uge:
2494 // (float)int >= -4.4 --> true
2495 // (float)int >= 4.4 --> int > 4
2496 if (rhs.isNegative()) {
2497 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2498 /*width=*/1);
2499 return success();
2500 }
2501 pred = CmpIPredicate::ugt;
2502 break;
2503 case CmpIPredicate::sge:
2504 // (float)int >= -4.4 --> int >= -4
2505 // (float)int >= 4.4 --> int > 4
2506 if (!rhs.isNegative())
2507 pred = CmpIPredicate::sgt;
2508 break;
2509 }
2510 }
2511 }
2512
2513 // Lower this FP comparison into an appropriate integer version of the
2514 // comparison.
2515 rewriter.replaceOpWithNewOp<CmpIOp>(
2516 op, pred, intVal,
2517 ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2518 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2519 return success();
2520 }
2521};
2522
2523void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2524 MLIRContext *context) {
2525 patterns.insert<CmpFIntToFPConst>(context);
2526}
2527
2528//===----------------------------------------------------------------------===//
2529// SelectOp
2530//===----------------------------------------------------------------------===//
2531
2532// select %arg, %c1, %c0 => extui %arg
2533struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2534 using Base::Base;
2535
2536 LogicalResult matchAndRewrite(arith::SelectOp op,
2537 PatternRewriter &rewriter) const override {
2538 // Cannot extui i1 to i1, or i1 to f32
2539 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2540 return failure();
2541
2542 // select %x, c1, %c0 => extui %arg
2543 if (matchPattern(op.getTrueValue(), m_One()) &&
2544 matchPattern(op.getFalseValue(), m_Zero())) {
2545 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2546 op.getCondition());
2547 return success();
2548 }
2549
2550 // select %x, c0, %c1 => extui (xor %arg, true)
2551 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2552 matchPattern(op.getFalseValue(), m_One())) {
2553 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2554 op, op.getType(),
2555 arith::XOrIOp::create(
2556 rewriter, op.getLoc(), op.getCondition(),
2557 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2558 op.getCondition().getType(), 1)));
2559 return success();
2560 }
2561
2562 return failure();
2563 }
2564};
2565
2566void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2567 MLIRContext *context) {
2568 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2569 SelectI1ToNot, SelectToExtUI>(context);
2570}
2571
2572OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2573 Value trueVal = getTrueValue();
2574 Value falseVal = getFalseValue();
2575 if (trueVal == falseVal)
2576 return trueVal;
2577
2578 Value condition = getCondition();
2579
2580 // select true, %0, %1 => %0
2581 if (matchPattern(adaptor.getCondition(), m_One()))
2582 return trueVal;
2583
2584 // select false, %0, %1 => %1
2585 if (matchPattern(adaptor.getCondition(), m_Zero()))
2586 return falseVal;
2587
2588 // If either operand is fully poisoned, return the other.
2589 if (matchPattern(adaptor.getTrueValue(), ub::m_Poison()))
2590 return falseVal;
2591
2592 if (matchPattern(adaptor.getFalseValue(), ub::m_Poison()))
2593 return trueVal;
2594
2595 // select %x, true, false => %x
2596 if (getType().isSignlessInteger(1) &&
2597 matchPattern(adaptor.getTrueValue(), m_One()) &&
2598 matchPattern(adaptor.getFalseValue(), m_Zero()))
2599 return condition;
2600
2601 if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2602 auto pred = cmp.getPredicate();
2603 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2604 auto cmpLhs = cmp.getLhs();
2605 auto cmpRhs = cmp.getRhs();
2606
2607 // %0 = arith.cmpi eq, %arg0, %arg1
2608 // %1 = arith.select %0, %arg0, %arg1 => %arg1
2609
2610 // %0 = arith.cmpi ne, %arg0, %arg1
2611 // %1 = arith.select %0, %arg0, %arg1 => %arg0
2612
2613 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2614 (cmpRhs == trueVal && cmpLhs == falseVal))
2615 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2616 }
2617 }
2618
2619 // Constant-fold constant operands over non-splat constant condition.
2620 // select %cst_vec, %cst0, %cst1 => %cst2
2621 if (auto cond =
2622 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2623 // DenseElementsAttr by construction always has a static shape.
2624 assert(cond.getType().hasStaticShape() &&
2625 "DenseElementsAttr must have static shape");
2626 if (auto lhs =
2627 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2628 if (auto rhs =
2629 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2630 SmallVector<Attribute> results;
2631 results.reserve(static_cast<size_t>(cond.getNumElements()));
2632 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2633 cond.value_end<BoolAttr>());
2634 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2635 lhs.value_end<Attribute>());
2636 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2637 rhs.value_end<Attribute>());
2638
2639 for (auto [condVal, lhsVal, rhsVal] :
2640 llvm::zip_equal(condVals, lhsVals, rhsVals))
2641 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2642
2643 return DenseElementsAttr::get(lhs.getType(), results);
2644 }
2645 }
2646 }
2647
2648 return nullptr;
2649}
2650
2651ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2652 Type conditionType, resultType;
2653 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2654 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2655 parser.parseOptionalAttrDict(result.attributes) ||
2656 parser.parseColonType(resultType))
2657 return failure();
2658
2659 // Check for the explicit condition type if this is a masked tensor or vector.
2660 if (succeeded(parser.parseOptionalComma())) {
2661 conditionType = resultType;
2662 if (parser.parseType(resultType))
2663 return failure();
2664 } else {
2665 conditionType = parser.getBuilder().getI1Type();
2666 }
2667
2668 result.addTypes(resultType);
2669 return parser.resolveOperands(operands,
2670 {conditionType, resultType, resultType},
2671 parser.getNameLoc(), result.operands);
2672}
2673
2674void arith::SelectOp::print(OpAsmPrinter &p) {
2675 p << " " << getOperands();
2676 p.printOptionalAttrDict((*this)->getAttrs());
2677 p << " : ";
2678 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2679 p << condType << ", ";
2680 p << getType();
2681}
2682
2683LogicalResult arith::SelectOp::verify() {
2684 Type conditionType = getCondition().getType();
2685 if (conditionType.isSignlessInteger(1))
2686 return success();
2687
2688 // If the result type is a vector or tensor, the type can be a mask with the
2689 // same elements.
2690 Type resultType = getType();
2691 if (!llvm::isa<TensorType, VectorType>(resultType))
2692 return emitOpError() << "expected condition to be a signless i1, but got "
2693 << conditionType;
2694 Type shapedConditionType = getI1SameShape(resultType);
2695 if (conditionType != shapedConditionType) {
2696 return emitOpError() << "expected condition type to have the same shape "
2697 "as the result type, expected "
2698 << shapedConditionType << ", but got "
2699 << conditionType;
2700 }
2701 return success();
2702}
2703//===----------------------------------------------------------------------===//
2704// ShLIOp
2705//===----------------------------------------------------------------------===//
2706
2707OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2708 // shli(x, 0) -> x
2709 if (matchPattern(adaptor.getRhs(), m_Zero()))
2710 return getLhs();
2711 // Don't fold if shifting more or equal than the bit width.
2712 bool bounded = false;
2714 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2715 bounded = b.ult(b.getBitWidth());
2716 return a.shl(b);
2717 });
2718 return bounded ? result : Attribute();
2719}
2720
2721//===----------------------------------------------------------------------===//
2722// ShRUIOp
2723//===----------------------------------------------------------------------===//
2724
2725OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2726 // shrui(x, 0) -> x
2727 if (matchPattern(adaptor.getRhs(), m_Zero()))
2728 return getLhs();
2729 // Don't fold if shifting more or equal than the bit width.
2730 bool bounded = false;
2732 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2733 bounded = b.ult(b.getBitWidth());
2734 return a.lshr(b);
2735 });
2736 return bounded ? result : Attribute();
2737}
2738
2739//===----------------------------------------------------------------------===//
2740// ShRSIOp
2741//===----------------------------------------------------------------------===//
2742
2743OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2744 // shrsi(x, 0) -> x
2745 if (matchPattern(adaptor.getRhs(), m_Zero()))
2746 return getLhs();
2747 // Don't fold if shifting more or equal than the bit width.
2748 bool bounded = false;
2750 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2751 bounded = b.ult(b.getBitWidth());
2752 return a.ashr(b);
2753 });
2754 return bounded ? result : Attribute();
2755}
2756
2757//===----------------------------------------------------------------------===//
2758// Atomic Enum
2759//===----------------------------------------------------------------------===//
2760
2761/// Returns the identity value attribute associated with an AtomicRMWKind op.
2762TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2763 OpBuilder &builder, Location loc,
2764 bool useOnlyFiniteValue) {
2765 switch (kind) {
2766 case AtomicRMWKind::maximumf: {
2767 const llvm::fltSemantics &semantic =
2768 llvm::cast<FloatType>(resultType).getFloatSemantics();
2769 APFloat identity = useOnlyFiniteValue
2770 ? APFloat::getLargest(semantic, /*Negative=*/true)
2771 : APFloat::getInf(semantic, /*Negative=*/true);
2772 return builder.getFloatAttr(resultType, identity);
2773 }
2774 case AtomicRMWKind::maxnumf: {
2775 const llvm::fltSemantics &semantic =
2776 llvm::cast<FloatType>(resultType).getFloatSemantics();
2777 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2778 return builder.getFloatAttr(resultType, identity);
2779 }
2780 case AtomicRMWKind::addf:
2781 case AtomicRMWKind::addi:
2782 case AtomicRMWKind::maxu:
2783 case AtomicRMWKind::ori:
2784 case AtomicRMWKind::xori:
2785 return builder.getZeroAttr(resultType);
2786 case AtomicRMWKind::andi:
2787 return builder.getIntegerAttr(
2788 resultType,
2789 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2790 case AtomicRMWKind::maxs:
2791 return builder.getIntegerAttr(
2792 resultType, APInt::getSignedMinValue(
2793 llvm::cast<IntegerType>(resultType).getWidth()));
2794 case AtomicRMWKind::minimumf: {
2795 const llvm::fltSemantics &semantic =
2796 llvm::cast<FloatType>(resultType).getFloatSemantics();
2797 APFloat identity = useOnlyFiniteValue
2798 ? APFloat::getLargest(semantic, /*Negative=*/false)
2799 : APFloat::getInf(semantic, /*Negative=*/false);
2800
2801 return builder.getFloatAttr(resultType, identity);
2802 }
2803 case AtomicRMWKind::minnumf: {
2804 const llvm::fltSemantics &semantic =
2805 llvm::cast<FloatType>(resultType).getFloatSemantics();
2806 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2807 return builder.getFloatAttr(resultType, identity);
2808 }
2809 case AtomicRMWKind::mins:
2810 return builder.getIntegerAttr(
2811 resultType, APInt::getSignedMaxValue(
2812 llvm::cast<IntegerType>(resultType).getWidth()));
2813 case AtomicRMWKind::minu:
2814 return builder.getIntegerAttr(
2815 resultType,
2816 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2817 case AtomicRMWKind::muli:
2818 return builder.getIntegerAttr(resultType, 1);
2819 case AtomicRMWKind::mulf:
2820 return builder.getFloatAttr(resultType, 1);
2821 // TODO: Add remaining reduction operations.
2822 default:
2823 (void)emitOptionalError(loc, "Reduction operation type not supported");
2824 break;
2825 }
2826 return nullptr;
2827}
2828
2829/// Returns the identity numeric value of the given op.
2830std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2831 std::optional<AtomicRMWKind> maybeKind =
2833 // Floating-point operations.
2834 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2835 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2836 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2837 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2838 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2839 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2840 // Integer operations.
2841 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2842 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2843 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2844 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2845 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2846 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2847 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2848 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2849 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2850 .Default(std::nullopt);
2851 if (!maybeKind) {
2852 return std::nullopt;
2853 }
2854
2855 bool useOnlyFiniteValue = false;
2856 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2857 if (fmfOpInterface) {
2858 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2859 useOnlyFiniteValue =
2860 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2861 }
2862
2863 // Builder only used as helper for attribute creation.
2864 OpBuilder b(op->getContext());
2865 Type resultType = op->getResult(0).getType();
2866
2867 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2868 useOnlyFiniteValue);
2869}
2870
2871/// Returns the identity value associated with an AtomicRMWKind op.
2872Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2873 OpBuilder &builder, Location loc,
2874 bool useOnlyFiniteValue) {
2875 if (auto attr =
2876getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue))
2877 return arith::ConstantOp::create(builder, loc, attr);
2878 return {};
2879}
2880
2881/// Return the value obtained by applying the reduction operation kind
2882/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2884 Location loc, Value lhs, Value rhs) {
2885 switch (op) {
2886 case AtomicRMWKind::addf:
2887 return arith::AddFOp::create(builder, loc, lhs, rhs);
2888 case AtomicRMWKind::addi:
2889 return arith::AddIOp::create(builder, loc, lhs, rhs);
2890 case AtomicRMWKind::mulf:
2891 return arith::MulFOp::create(builder, loc, lhs, rhs);
2892 case AtomicRMWKind::muli:
2893 return arith::MulIOp::create(builder, loc, lhs, rhs);
2894 case AtomicRMWKind::maximumf:
2895 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2896 case AtomicRMWKind::minimumf:
2897 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2898 case AtomicRMWKind::maxnumf:
2899 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2900 case AtomicRMWKind::minnumf:
2901 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2902 case AtomicRMWKind::maxs:
2903 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2904 case AtomicRMWKind::mins:
2905 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2906 case AtomicRMWKind::maxu:
2907 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2908 case AtomicRMWKind::minu:
2909 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2910 case AtomicRMWKind::ori:
2911 return arith::OrIOp::create(builder, loc, lhs, rhs);
2912 case AtomicRMWKind::andi:
2913 return arith::AndIOp::create(builder, loc, lhs, rhs);
2914 case AtomicRMWKind::xori:
2915 return arith::XOrIOp::create(builder, loc, lhs, rhs);
2916 // TODO: Add remaining reduction operations.
2917 default:
2918 (void)emitOptionalError(loc, "Reduction operation type not supported");
2919 break;
2920 }
2921 return nullptr;
2922}
2923
2924//===----------------------------------------------------------------------===//
2925// TableGen'd op method definitions
2926//===----------------------------------------------------------------------===//
2927
2928#define GET_OP_CLASSES
2929#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2930
2931//===----------------------------------------------------------------------===//
2932// TableGen'd enum attribute definitions
2933//===----------------------------------------------------------------------===//
2934
2935#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:721
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:682
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensor with mismatching encoding.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:781
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:763
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:149
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:129
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition ArithOps.cpp:975
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:173
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:437
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:141
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:665
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isIndex() const
Definition Types.cpp:56
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:334
static bool classof(Operation *op)
Definition ArithOps.cpp:351
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:328
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:357
static bool classof(Operation *op)
Definition ArithOps.cpp:378
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:255
static bool classof(Operation *op)
Definition ArithOps.cpp:322
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:75
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:384
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.