MLIR 22.0.0git
ArithOps.cpp
Go to the documentation of this file.
1//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cassert>
10#include <cstdint>
11#include <functional>
12#include <utility>
13
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25
26#include "llvm/ADT/APFloat.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/APSInt.h"
29#include "llvm/ADT/FloatingPointMode.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
33
34using namespace mlir;
35using namespace mlir::arith;
36
37//===----------------------------------------------------------------------===//
38// Pattern helpers
39//===----------------------------------------------------------------------===//
40
// NOTE(review): no caller of these four helpers is visible in this chunk;
// presumably they are referenced from the generated ArithCanonicalization.inc
// patterns — confirm.
/// Shared helper: unwraps `lhs` and `rhs` with llvm::cast<IntegerAttr> (so
/// both must be IntegerAttrs), applies `binFn` to their APInt payloads, and
/// returns the result as an IntegerAttr of `res`'s type. Callers below pass
/// (builder, res, lhs, rhs, fn); `builder` is not used by this body.
41static IntegerAttr
44 function_ref<APInt(const APInt &, const APInt &)> binFn) {
45 APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
46 APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
47 APInt value = binFn(lhsVal, rhsVal);
48 return IntegerAttr::get(res.getType(), value);
49}
50
/// IntegerAttr addition via APInt operator+ (std::plus).
51static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
53 return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
54}
55
/// IntegerAttr subtraction via APInt operator- (std::minus).
56static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
58 return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
59}
60
/// IntegerAttr multiplication via APInt operator* (std::multiplies).
61static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
63 return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
64}
65
66// Merge overflow flags from 2 ops, selecting the most conservative combination.
67static IntegerOverflowFlagsAttr
68mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
69 IntegerOverflowFlagsAttr val2) {
70 return IntegerOverflowFlagsAttr::get(val1.getContext(),
71 val1.getValue() & val2.getValue());
72}
73
74/// Invert an integer comparison predicate.
75arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
76 switch (pred) {
77 case arith::CmpIPredicate::eq:
78 return arith::CmpIPredicate::ne;
79 case arith::CmpIPredicate::ne:
80 return arith::CmpIPredicate::eq;
81 case arith::CmpIPredicate::slt:
82 return arith::CmpIPredicate::sge;
83 case arith::CmpIPredicate::sle:
84 return arith::CmpIPredicate::sgt;
85 case arith::CmpIPredicate::sgt:
86 return arith::CmpIPredicate::sle;
87 case arith::CmpIPredicate::sge:
88 return arith::CmpIPredicate::slt;
89 case arith::CmpIPredicate::ult:
90 return arith::CmpIPredicate::uge;
91 case arith::CmpIPredicate::ule:
92 return arith::CmpIPredicate::ugt;
93 case arith::CmpIPredicate::ugt:
94 return arith::CmpIPredicate::ule;
95 case arith::CmpIPredicate::uge:
96 return arith::CmpIPredicate::ult;
97 }
98 llvm_unreachable("unknown cmpi predicate kind");
99}
100
101/// Equivalent to
102/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
103///
104/// Not possible to implement as chain of calls as this would introduce a
105/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
106/// on the LLVM dialect and on translation to LLVM.
107static llvm::RoundingMode
108convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
109 switch (roundingMode) {
110 case RoundingMode::downward:
111 return llvm::RoundingMode::TowardNegative;
112 case RoundingMode::to_nearest_away:
113 return llvm::RoundingMode::NearestTiesToAway;
114 case RoundingMode::to_nearest_even:
115 return llvm::RoundingMode::NearestTiesToEven;
116 case RoundingMode::toward_zero:
117 return llvm::RoundingMode::TowardZero;
118 case RoundingMode::upward:
119 return llvm::RoundingMode::TowardPositive;
120 }
121 llvm_unreachable("Unhandled rounding mode");
122}
123
124static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
125 return arith::CmpIPredicateAttr::get(pred.getContext(),
126 invertPredicate(pred.getValue()));
127}
128
// Bit width of the element type of `type` (getElementTypeOrSelf, so a
// non-container type is treated as its own element type) when that element
// type is an integer or float; -1 otherwise.
130 Type elemTy = getElementTypeOrSelf(type);
131 if (elemTy.isIntOrFloat())
132 return elemTy.getIntOrFloatBitWidth();
133
134 return -1;
135}
136
// Forwards to getScalarOrElementWidth on `value`'s type.
138 return getScalarOrElementWidth(value.getType());
139}
140
141static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
142 APInt value;
143 if (matchPattern(attr, m_ConstantInt(&value)))
144 return value;
145
146 return failure();
147}
148
149static Attribute getBoolAttribute(Type type, bool value) {
150 auto boolAttr = BoolAttr::get(type.getContext(), value);
151 ShapedType shapedType = dyn_cast_or_null<ShapedType>(type);
152 if (!shapedType)
153 return boolAttr;
154 return DenseElementsAttr::get(shapedType, boolAttr);
155}
156
157//===----------------------------------------------------------------------===//
158// TableGen'd canonicalization patterns
159//===----------------------------------------------------------------------===//
160
161namespace {
162#include "ArithCanonicalization.inc"
163} // namespace
164
165//===----------------------------------------------------------------------===//
166// Common helpers
167//===----------------------------------------------------------------------===//
168
169/// Return the type of the same shape (scalar, vector or tensor) containing i1.
171 auto i1Type = IntegerType::get(type.getContext(), 1);
172 if (auto shapedType = dyn_cast<ShapedType>(type))
173 return shapedType.cloneWith(std::nullopt, i1Type);
// NOTE(review): UnrankedTensorType appears to also satisfy the ShapedType
// check above (cloneWith with no shape keeps it unranked), which would make
// this branch unreachable — confirm before relying on it.
174 if (llvm::isa<UnrankedTensorType>(type))
175 return UnrankedTensorType::get(i1Type);
176 return i1Type;
177}
178
179//===----------------------------------------------------------------------===//
180// ConstantOp
181//===----------------------------------------------------------------------===//
182
183void arith::ConstantOp::getAsmResultNames(
184 function_ref<void(Value, StringRef)> setNameFn) {
185 auto type = getType();
186 if (auto intCst = dyn_cast<IntegerAttr>(getValue())) {
187 auto intType = dyn_cast<IntegerType>(type);
188
189 // Sugar i1 constants with 'true' and 'false'.
190 if (intType && intType.getWidth() == 1)
191 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
192
193 // Otherwise, build a complex name with the value and type.
194 SmallString<32> specialNameBuffer;
195 llvm::raw_svector_ostream specialName(specialNameBuffer);
196 specialName << 'c' << intCst.getValue();
197 if (intType)
198 specialName << '_' << type;
199 setNameFn(getResult(), specialName.str());
200 } else {
201 setNameFn(getResult(), "cst");
202 }
203}
204
205/// TODO: disallow arith.constant to return anything other than signless integer
206/// or float like.
207LogicalResult arith::ConstantOp::verify() {
208 auto type = getType();
209 // Integer values must be signless.
210 if (llvm::isa<IntegerType>(type) &&
211 !llvm::cast<IntegerType>(type).isSignless())
212 return emitOpError("integer return type must be signless");
213 // Any float or elements attribute are acceptable.
214 if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
215 return emitOpError(
216 "value must be an integer, float, or elements attribute");
217 }
218
219 // Note, we could relax this for vectors with 1 scalable dim, e.g.:
220 // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
221 // However, this would most likely require updating the lowerings to LLVM.
222 if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
223 return emitOpError(
224 "intializing scalable vectors with elements attribute is not supported"
225 " unless it's a vector splat");
226 return success();
227}
228
229bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
230 // The value's type must be the same as the provided type.
231 auto typedAttr = dyn_cast<TypedAttr>(value);
232 if (!typedAttr || typedAttr.getType() != type)
233 return false;
234 // Integer values must be signless.
235 if (llvm::isa<IntegerType>(type) &&
236 !llvm::cast<IntegerType>(type).isSignless())
237 return false;
238 // Integer, float, and element attributes are buildable.
239 return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
240}
241
242ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
243 Type type, Location loc) {
244 if (isBuildableWith(value, type))
245 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
246 return nullptr;
247}
248
249OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
250
// (value, width) builder: creates builder.getIntegerType(width), wraps `value`
// in an IntegerAttr of that type and forwards to arith::ConstantOp::build.
252 int64_t value, unsigned width) {
253 auto type = builder.getIntegerType(width);
254 arith::ConstantOp::build(builder, result, type,
255 builder.getIntegerAttr(type, value));
256}
257
// create(value, width): fills an OperationState through the matching build()
// overload, creates the op and asserts that it is a ConstantIntOp.
259 Location location,
261 unsigned width) {
262 mlir::OperationState state(location, getOperationName());
263 build(builder, state, value, width);
264 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
265 assert(result && "builder didn't return the right type");
266 return result;
267}
268
// Convenience overload: uses builder.getLoc() as the location.
271 unsigned width) {
272 return create(builder, builder.getLoc(), value, width);
273}
274
// (type, int64_t) builder: IntegerAttr of `value` in the given `type`.
276 Type type, int64_t value) {
277 arith::ConstantOp::build(builder, result, type,
278 builder.getIntegerAttr(type, value));
279}
280
// create(type, int64_t): same OperationState + assert pattern as above.
282 Location location, Type type,
283 int64_t value) {
284 mlir::OperationState state(location, getOperationName());
285 build(builder, state, type, value);
286 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
287 assert(result && "builder didn't return the right type");
288 return result;
289}
290
// Convenience overload: uses builder.getLoc() as the location.
292 Type type, int64_t value) {
293 return create(builder, builder.getLoc(), type, value);
294}
295
// (type, APInt) builder: IntegerAttr of `value` in the given `type`.
297 Type type, const APInt &value) {
298 arith::ConstantOp::build(builder, result, type,
299 builder.getIntegerAttr(type, value));
300}
301
// create(type, APInt): same OperationState + assert pattern as above.
303 Location location, Type type,
304 const APInt &value) {
305 mlir::OperationState state(location, getOperationName());
306 build(builder, state, type, value);
307 auto result = dyn_cast<ConstantIntOp>(builder.create(state));
308 assert(result && "builder didn't return the right type");
309 return result;
310}
311
// Convenience overload: uses builder.getLoc() as the location.
313 Type type,
314 const APInt &value) {
315 return create(builder, builder.getLoc(), type, value);
316}
317
// Type test: accepts only arith.constant ops (null-safe via dyn_cast_or_null)
// whose result type is a signless integer.
319 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
320 return constOp.getType().isSignlessInteger();
321 return false;
322}
323
// (FloatType, APFloat) builder: wraps `value` in a FloatAttr of `type` via
// builder.getFloatAttr and forwards to arith::ConstantOp::build.
325 FloatType type, const APFloat &value) {
326 arith::ConstantOp::build(builder, result, type,
327 builder.getFloatAttr(type, value));
328}
329
// create(): fills an OperationState through build(), creates the op and
// asserts that it is a ConstantFloatOp.
331 Location location,
332 FloatType type,
333 const APFloat &value) {
334 mlir::OperationState state(location, getOperationName());
335 build(builder, state, type, value);
336 auto result = dyn_cast<ConstantFloatOp>(builder.create(state));
337 assert(result && "builder didn't return the right type");
338 return result;
339}
340
// Convenience overload: uses builder.getLoc() as the location.
343 const APFloat &value) {
344 return create(builder, builder.getLoc(), type, value);
345}
346
// Type test: accepts only arith.constant ops (null-safe via dyn_cast_or_null)
// whose result type is a FloatType.
348 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
349 return llvm::isa<FloatType>(constOp.getType());
350 return false;
351}
352
// Builder: an index-typed constant (builder.getIndexType / getIndexAttr)
// holding `value`.
354 int64_t value) {
355 arith::ConstantOp::build(builder, result, builder.getIndexType(),
356 builder.getIndexAttr(value));
357}
358
// create(): fills an OperationState through build(), creates the op and
// asserts that it is a ConstantIndexOp.
360 Location location,
361 int64_t value) {
362 mlir::OperationState state(location, getOperationName());
363 build(builder, state, value);
364 auto result = dyn_cast<ConstantIndexOp>(builder.create(state));
365 assert(result && "builder didn't return the right type");
366 return result;
367}
368
// Convenience overload: uses builder.getLoc() as the location.
371 return create(builder, builder.getLoc(), value);
372}
373
// Type test: accepts only arith.constant ops (null-safe via dyn_cast_or_null)
// whose result type is `index`.
375 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
376 return constOp.getType().isIndex();
377 return false;
378}
379
// Creates an arith.constant holding builder.getZeroAttr(type). Asserts that the
// element type is not Float8E8M0FNUType (which has no zero representation) and
// that getZeroAttr produced an attribute for `type`.
381 Type type) {
382 // TODO: Incorporate this check to `FloatAttr::get*`.
383 assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
384 "type doesn't have a zero representation");
385 TypedAttr zeroAttr = builder.getZeroAttr(type);
386 assert(zeroAttr && "unsupported type for zero attribute");
387 return arith::ConstantOp::create(builder, loc, zeroAttr);
388}
389
390//===----------------------------------------------------------------------===//
391// AddIOp
392//===----------------------------------------------------------------------===//
393
/// Folds arith.addi:
///  * addi(x, 0) -> x;
///  * addi(subi(a, b), b) -> a and addi(b, subi(a, b)) -> a;
///  * otherwise constant operands are combined with APInt `+` through
///    constFoldBinaryOp.
394OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
395 // addi(x, 0) -> x
396 if (matchPattern(adaptor.getRhs(), m_Zero()))
397 return getLhs();
398
399 // addi(subi(a, b), b) -> a
400 if (auto sub = getLhs().getDefiningOp<SubIOp>())
401 if (getRhs() == sub.getRhs())
402 return sub.getLhs();
403
404 // addi(b, subi(a, b)) -> a
405 if (auto sub = getRhs().getDefiningOp<SubIOp>())
406 if (getLhs() == sub.getRhs())
407 return sub.getLhs();
408
410 adaptor.getOperands(),
411 [](APInt a, const APInt &b) { return std::move(a) + b; });
412}
413
414void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
415 MLIRContext *context) {
416 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
417 AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
418}
419
420//===----------------------------------------------------------------------===//
421// AddUIExtendedOp
422//===----------------------------------------------------------------------===//
423
424std::optional<SmallVector<int64_t, 4>>
425arith::AddUIExtendedOp::getShapeForUnroll() {
426 if (auto vt = dyn_cast<VectorType>(getType(0)))
427 return llvm::to_vector<4>(vt.getShape());
428 return std::nullopt;
429}
430
431// Returns the overflow bit, assuming that `sum` is the result of unsigned
432// addition of `operand` and another number.
433static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
434 return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
435}
436
/// Folds arith.addui_extended into its (sum, overflow) results:
///  * rhs defined by a zero constant -> (lhs, zero attribute of the overflow
///    type);
///  * both operands constant -> (constant sum, constant carry). The carry is
///    computed by a second constFoldBinaryOp over (sum, lhs), presumably using
///    calculateUnsignedOverflow.
/// Returns failure when neither case applies or the carry does not fold.
437LogicalResult
438arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
439 SmallVectorImpl<OpFoldResult> &results) {
440 Type overflowTy = getOverflow().getType();
441 // addui_extended(x, 0) -> x, false
442 if (matchPattern(getRhs(), m_Zero())) {
443 Builder builder(getContext());
444 auto falseValue = builder.getZeroAttr(overflowTy);
445
446 results.push_back(getLhs());
447 results.push_back(falseValue);
448 return success();
449 }
450
451 // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
452 // Let the `constFoldBinaryOp` utility attempt to fold the sum of both
453 // operands. If that succeeds, calculate the overflow bit based on the sum
454 // and the first (constant) operand, `lhs`.
455 if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
456 adaptor.getOperands(),
457 [](APInt a, const APInt &b) { return std::move(a) + b; })) {
458 Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
459 ArrayRef({sumAttr, adaptor.getLhs()}),
460 getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType()),
462 if (!overflowAttr)
463 return failure();
464
465 results.push_back(sumAttr);
466 results.push_back(overflowAttr);
467 return success();
468 }
469
470 return failure();
471}
472
473void arith::AddUIExtendedOp::getCanonicalizationPatterns(
474 RewritePatternSet &patterns, MLIRContext *context) {
475 patterns.add<AddUIExtendedToAddI>(context);
476}
477
478//===----------------------------------------------------------------------===//
479// SubIOp
480//===----------------------------------------------------------------------===//
481
/// Folds arith.subi:
///  * subi(x, x) -> 0, unless the result is a shaped type without a static
///    shape (no constant can be built for it);
///  * subi(x, 0) -> x;
///  * subi(addi(a, b), b) -> a and subi(addi(a, b), a) -> b;
///  * otherwise constant operands are combined with APInt `-` through
///    constFoldBinaryOp.
482OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
483 // subi(x,x) -> 0
484 if (getOperand(0) == getOperand(1)) {
485 auto shapedType = dyn_cast<ShapedType>(getType());
486 // We can't generate a constant with a dynamic shaped tensor.
487 if (!shapedType || shapedType.hasStaticShape())
488 return Builder(getContext()).getZeroAttr(getType());
489 }
490 // subi(x,0) -> x
491 if (matchPattern(adaptor.getRhs(), m_Zero()))
492 return getLhs();
493
494 if (auto add = getLhs().getDefiningOp<AddIOp>()) {
495 // subi(addi(a, b), b) -> a
496 if (getRhs() == add.getRhs())
497 return add.getLhs();
498 // subi(addi(a, b), a) -> b
499 if (getRhs() == add.getLhs())
500 return add.getRhs();
501 }
502
504 adaptor.getOperands(),
505 [](APInt a, const APInt &b) { return std::move(a) - b; });
506}
507
508void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
509 MLIRContext *context) {
510 patterns.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
511 SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
512 SubILHSSubConstantLHS, SubISubILHSRHSLHS>(context);
513}
514
515//===----------------------------------------------------------------------===//
516// MulIOp
517//===----------------------------------------------------------------------===//
518
/// Folds arith.muli:
///  * muli(x, 0) -> the zero operand itself;
///  * muli(x, 1) -> x;
///  * otherwise constant operands are multiplied (APInt `*`) through
///    constFoldBinaryOp.
519OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
520 // muli(x, 0) -> 0
521 if (matchPattern(adaptor.getRhs(), m_Zero()))
522 return getRhs();
523 // muli(x, 1) -> x
524 if (matchPattern(adaptor.getRhs(), m_One()))
525 return getLhs();
526 // TODO: Handle the overflow case.
527
528 // default folder
530 adaptor.getOperands(),
531 [](const APInt &a, const APInt &b) { return a * b; });
532}
533
534void arith::MulIOp::getAsmResultNames(
535 function_ref<void(Value, StringRef)> setNameFn) {
536 if (!isa<IndexType>(getType()))
537 return;
538
539 // Match vector.vscale by name to avoid depending on the vector dialect (which
540 // is a circular dependency).
541 auto isVscale = [](Operation *op) {
542 return op && op->getName().getStringRef() == "vector.vscale";
543 };
544
545 IntegerAttr baseValue;
546 auto isVscaleExpr = [&](Value a, Value b) {
547 return matchPattern(a, m_Constant(&baseValue)) &&
548 isVscale(b.getDefiningOp());
549 };
550
551 if (!isVscaleExpr(getLhs(), getRhs()) && !isVscaleExpr(getRhs(), getLhs()))
552 return;
553
554 // Name `base * vscale` or `vscale * base` as `c<base_value>_vscale`.
555 SmallString<32> specialNameBuffer;
556 llvm::raw_svector_ostream specialName(specialNameBuffer);
557 specialName << 'c' << baseValue.getInt() << "_vscale";
558 setNameFn(getResult(), specialName.str());
559}
560
561void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
562 MLIRContext *context) {
563 patterns.add<MulIMulIConstant>(context);
564}
565
566//===----------------------------------------------------------------------===//
567// MulSIExtendedOp
568//===----------------------------------------------------------------------===//
569
570std::optional<SmallVector<int64_t, 4>>
571arith::MulSIExtendedOp::getShapeForUnroll() {
572 if (auto vt = dyn_cast<VectorType>(getType(0)))
573 return llvm::to_vector<4>(vt.getShape());
574 return std::nullopt;
575}
576
577LogicalResult
578arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
579 SmallVectorImpl<OpFoldResult> &results) {
580 // mulsi_extended(x, 0) -> 0, 0
581 if (matchPattern(adaptor.getRhs(), m_Zero())) {
582 Attribute zero = adaptor.getRhs();
583 results.push_back(zero);
584 results.push_back(zero);
585 return success();
586 }
587
588 // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
589 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
590 adaptor.getOperands(),
591 [](const APInt &a, const APInt &b) { return a * b; })) {
592 // Invoke the constant fold helper again to calculate the 'high' result.
593 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
594 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
595 return llvm::APIntOps::mulhs(a, b);
596 });
597 assert(highAttr && "Unexpected constant-folding failure");
598
599 results.push_back(lowAttr);
600 results.push_back(highAttr);
601 return success();
602 }
603
604 return failure();
605}
606
607void arith::MulSIExtendedOp::getCanonicalizationPatterns(
608 RewritePatternSet &patterns, MLIRContext *context) {
609 patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
610}
611
612//===----------------------------------------------------------------------===//
613// MulUIExtendedOp
614//===----------------------------------------------------------------------===//
615
616std::optional<SmallVector<int64_t, 4>>
617arith::MulUIExtendedOp::getShapeForUnroll() {
618 if (auto vt = dyn_cast<VectorType>(getType(0)))
619 return llvm::to_vector<4>(vt.getShape());
620 return std::nullopt;
621}
622
623LogicalResult
624arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
625 SmallVectorImpl<OpFoldResult> &results) {
626 // mului_extended(x, 0) -> 0, 0
627 if (matchPattern(adaptor.getRhs(), m_Zero())) {
628 Attribute zero = adaptor.getRhs();
629 results.push_back(zero);
630 results.push_back(zero);
631 return success();
632 }
633
634 // mului_extended(x, 1) -> x, 0
635 if (matchPattern(adaptor.getRhs(), m_One())) {
636 Builder builder(getContext());
637 Attribute zero = builder.getZeroAttr(getLhs().getType());
638 results.push_back(getLhs());
639 results.push_back(zero);
640 return success();
641 }
642
643 // mului_extended(cst_a, cst_b) -> cst_low, cst_high
644 if (Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
645 adaptor.getOperands(),
646 [](const APInt &a, const APInt &b) { return a * b; })) {
647 // Invoke the constant fold helper again to calculate the 'high' result.
648 Attribute highAttr = constFoldBinaryOp<IntegerAttr>(
649 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
650 return llvm::APIntOps::mulhu(a, b);
651 });
652 assert(highAttr && "Unexpected constant-folding failure");
653
654 results.push_back(lowAttr);
655 results.push_back(highAttr);
656 return success();
657 }
658
659 return failure();
660}
661
662void arith::MulUIExtendedOp::getCanonicalizationPatterns(
663 RewritePatternSet &patterns, MLIRContext *context) {
664 patterns.add<MulUIExtendedToMulI>(context);
665}
666
667//===----------------------------------------------------------------------===//
668// DivUIOp
669//===----------------------------------------------------------------------===//
670
671/// Fold `(a * b) / b -> a`
/// If `lhs` is produced by an arith.muli whose overflow flags include every
/// flag in `ovfFlags`, and one of that muli's operands is `rhs`, returns the
/// other multiplicand. Otherwise returns a null Value. Requiring the flags
/// rules out a wrapped product, for which (a * b) / b would not equal a;
/// divui passes nuw and divsi passes nsw.
673 arith::IntegerOverflowFlags ovfFlags) {
674 auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
675 if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
676 return {};
677
678 if (mul.getLhs() == rhs)
679 return mul.getRhs();
680
681 if (mul.getRhs() == rhs)
682 return mul.getLhs();
683
684 return {};
685}
686
687OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
688 // divui (x, 1) -> x.
689 if (matchPattern(adaptor.getRhs(), m_One()))
690 return getLhs();
691
692 // (a * b) / b -> a
693 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
694 return val;
695
696 // Don't fold if it would require a division by zero.
697 bool div0 = false;
698 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
699 [&](APInt a, const APInt &b) {
700 if (div0 || !b) {
701 div0 = true;
702 return a;
703 }
704 return a.udiv(b);
705 });
706
707 return div0 ? Attribute() : result;
708}
709
710/// Returns whether an unsigned division by `divisor` is speculatable.
/// The decision hinges on m_IntRangeWithoutZeroU: the divisor must be provably
/// nonzero. NOTE(review): the returned values are presumably Speculatable when
/// it matches and NotSpeculatable otherwise — confirm.
712 // X / 0 => UB
713 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
715
717}
718
719Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
720 return getDivUISpeculatability(getRhs());
721}
722
723//===----------------------------------------------------------------------===//
724// DivSIOp
725//===----------------------------------------------------------------------===//
726
/// Folds arith.divsi:
///  * divsi(x, 1) -> x;
///  * divsi(muli(a, b), b) -> a when the multiply carries `nsw`;
///  * constant operands are divided with APInt::sdiv_ov. Nothing is folded if
///    any element divides by zero or overflows (INT_MIN / -1). The flag is
///    checked before each sdiv_ov call, so it stays set once raised.
727OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
728 // divsi (x, 1) -> x.
729 if (matchPattern(adaptor.getRhs(), m_One()))
730 return getLhs();
731
732 // (a * b) / b -> a
733 if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
734 return val;
735
736 // Don't fold if it would overflow or if it requires a division by zero.
737 bool overflowOrDiv0 = false;
739 adaptor.getOperands(), [&](APInt a, const APInt &b) {
740 if (overflowOrDiv0 || !b) {
741 overflowOrDiv0 = true;
742 return a;
743 }
744 return a.sdiv_ov(b, overflowOrDiv0);
745 });
746
747 return overflowOrDiv0 ? Attribute() : result;
748}
749
750/// Returns whether a signed division by `divisor` is speculatable. This
751/// function conservatively assumes that all signed division by -1 are not
752/// speculatable.
/// The first visible condition requires a provably nonzero divisor
/// (m_IntRangeWithoutZeroS). NOTE(review): the rest of the condition and the
/// returned values are not shown here; presumably they exclude a possible -1
/// divisor, as described above — confirm.
754 // X / 0 => UB
755 // INT_MIN / -1 => UB
756 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
759
761}
762
763Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
764 return getDivSISpeculatability(getRhs());
765}
766
767//===----------------------------------------------------------------------===//
768// Ceil and floor division folding helpers
769//===----------------------------------------------------------------------===//
770
771static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
772 bool &overflow) {
773 // Returns (a-1)/b + 1
774 APInt one(a.getBitWidth(), 1, true); // Signed value 1.
775 APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
776 return val.sadd_ov(one, overflow);
777}
778
779//===----------------------------------------------------------------------===//
780// CeilDivUIOp
781//===----------------------------------------------------------------------===//
782
/// Folds arith.ceildivui:
///  * ceildivui(x, 1) -> x;
///  * constant operands fold to udiv(a, b), plus one when urem(a, b) != 0.
///    Nothing is folded if any divisor element is zero or the +1 overflows.
///    The flag is checked before each element, so it stays set once raised.
783OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
784 // ceildivui (x, 1) -> x.
785 if (matchPattern(adaptor.getRhs(), m_One()))
786 return getLhs();
787
788 bool overflowOrDiv0 = false;
790 adaptor.getOperands(), [&](APInt a, const APInt &b) {
791 if (overflowOrDiv0 || !b) {
792 overflowOrDiv0 = true;
793 return a;
794 }
795 APInt quotient = a.udiv(b);
796 if (!a.urem(b))
797 return quotient;
798 APInt one(a.getBitWidth(), 1, true);
799 return quotient.uadd_ov(one, overflowOrDiv0);
800 });
801
802 return overflowOrDiv0 ? Attribute() : result;
803}
804
805Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
806 return getDivUISpeculatability(getRhs());
807}
808
809//===----------------------------------------------------------------------===//
810// CeilDivSIOp
811//===----------------------------------------------------------------------===//
812
/// Folds arith.ceildivsi:
///  * ceildivsi(x, 1) -> x;
///  * constant operands are folded per element by sign case:
///    - a == 0 gives 0;
///    - both positive: signedCeilNonnegInputs(a, b);
///    - both negative: signedCeilNonnegInputs(-a, -b);
///    - mixed signs: the negated truncating quotient of the magnitudes.
///    Any zero divisor or intermediate overflow (e.g. negating INT_MIN)
///    suppresses the whole fold.
813OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
814 // ceildivsi (x, 1) -> x.
815 if (matchPattern(adaptor.getRhs(), m_One()))
816 return getLhs();
817
818 // Don't fold if it would overflow or if it requires a division by zero.
819 // TODO: This hook won't fold operations where a = MININT, because
820 // negating MININT overflows. This can be improved.
821 bool overflowOrDiv0 = false;
823 adaptor.getOperands(), [&](APInt a, const APInt &b) {
824 if (overflowOrDiv0 || !b) {
825 overflowOrDiv0 = true;
826 return a;
827 }
828 if (!a)
829 return a;
830 // After this point we know that neither a or b are zero.
831 unsigned bits = a.getBitWidth();
832 APInt zero = APInt::getZero(bits);
833 bool aGtZero = a.sgt(zero);
834 bool bGtZero = b.sgt(zero);
835 if (aGtZero && bGtZero) {
836 // Both positive, return ceil(a, b).
837 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
838 }
839
840 // No folding happens if any of the intermediate arithmetic operations
841 // overflows.
842 bool overflowNegA = false;
843 bool overflowNegB = false;
844 bool overflowDiv = false;
845 bool overflowNegRes = false;
846 if (!aGtZero && !bGtZero) {
847 // Both negative, return ceil(-a, -b).
848 APInt posA = zero.ssub_ov(a, overflowNegA);
849 APInt posB = zero.ssub_ov(b, overflowNegB);
850 APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
851 overflowOrDiv0 = (overflowNegA || overflowNegB || overflowDiv);
852 return res;
853 }
854 if (!aGtZero && bGtZero) {
855 // A is negative, b is positive, return - ( -a / b).
856 APInt posA = zero.ssub_ov(a, overflowNegA);
857 APInt div = posA.sdiv_ov(b, overflowDiv);
858 APInt res = zero.ssub_ov(div, overflowNegRes);
859 overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegRes);
860 return res;
861 }
862 // A is positive, b is negative, return - (a / -b).
863 APInt posB = zero.ssub_ov(b, overflowNegB);
864 APInt div = a.sdiv_ov(posB, overflowDiv);
865 APInt res = zero.ssub_ov(div, overflowNegRes);
866
867 overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegRes);
868 return res;
869 });
870
871 return overflowOrDiv0 ? Attribute() : result;
872}
873
874Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
875 return getDivSISpeculatability(getRhs());
876}
877
878//===----------------------------------------------------------------------===//
879// FloorDivSIOp
880//===----------------------------------------------------------------------===//
881
882OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
883 // floordivsi (x, 1) -> x.
884 if (matchPattern(adaptor.getRhs(), m_One()))
885 return getLhs();
886
887 // Don't fold if it would overflow or if it requires a division by zero.
888 bool overflowOrDiv = false;
890 adaptor.getOperands(), [&](APInt a, const APInt &b) {
891 if (b.isZero()) {
892 overflowOrDiv = true;
893 return a;
894 }
895 return a.sfloordiv_ov(b, overflowOrDiv);
896 });
897
898 return overflowOrDiv ? Attribute() : result;
899}
900
901//===----------------------------------------------------------------------===//
902// RemUIOp
903//===----------------------------------------------------------------------===//
904
905OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
906 // remui (x, 1) -> 0.
907 if (matchPattern(adaptor.getRhs(), m_One()))
908 return Builder(getContext()).getZeroAttr(getType());
909
910 // Don't fold if it would require a division by zero.
911 bool div0 = false;
912 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
913 [&](APInt a, const APInt &b) {
914 if (div0 || b.isZero()) {
915 div0 = true;
916 return a;
917 }
918 return a.urem(b);
919 });
920
921 return div0 ? Attribute() : result;
922}
923
924//===----------------------------------------------------------------------===//
925// RemSIOp
926//===----------------------------------------------------------------------===//
927
928OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
929 // remsi (x, 1) -> 0.
930 if (matchPattern(adaptor.getRhs(), m_One()))
931 return Builder(getContext()).getZeroAttr(getType());
932
933 // Don't fold if it would require a division by zero.
934 bool div0 = false;
935 auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
936 [&](APInt a, const APInt &b) {
937 if (div0 || b.isZero()) {
938 div0 = true;
939 return a;
940 }
941 return a.srem(b);
942 });
943
944 return div0 ? Attribute() : result;
945}
946
947//===----------------------------------------------------------------------===//
948// AndIOp
949//===----------------------------------------------------------------------===//
950
951/// Fold `and(a, and(a, b))` to `and(a, b)`
952static Value foldAndIofAndI(arith::AndIOp op) {
953 for (bool reversePrev : {false, true}) {
954 auto prev = (reversePrev ? op.getRhs() : op.getLhs())
955 .getDefiningOp<arith::AndIOp>();
956 if (!prev)
957 continue;
958
959 Value other = (reversePrev ? op.getLhs() : op.getRhs());
960 if (other != prev.getLhs() && other != prev.getRhs())
961 continue;
962
963 return prev.getResult();
964 }
965 return {};
966}
967
/// Folds arith.andi:
///  * and(x, 0) -> the zero operand; and(x, all-ones) -> x;
///  * and(x, xor(x, all-ones)) and its mirror -> zero attribute of the type;
///  * and(a, and(a, b)) -> and(a, b) via foldAndIofAndI;
///  * otherwise constant operands are combined with APInt `&` through
///    constFoldBinaryOp.
968OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
969 /// and(x, 0) -> 0
970 if (matchPattern(adaptor.getRhs(), m_Zero()))
971 return getRhs();
972 /// and(x, allOnes) -> x
973 APInt intValue;
974 if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
975 intValue.isAllOnes())
976 return getLhs();
977 /// and(x, not(x)) -> 0
978 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
979 m_ConstantInt(&intValue))) &&
980 intValue.isAllOnes())
981 return Builder(getContext()).getZeroAttr(getType());
982 /// and(not(x), x) -> 0
983 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
984 m_ConstantInt(&intValue))) &&
985 intValue.isAllOnes())
986 return Builder(getContext()).getZeroAttr(getType());
987
988 /// and(a, and(a, b)) -> and(a, b)
989 if (Value result = foldAndIofAndI(*this))
990 return result;
991
993 adaptor.getOperands(),
994 [](APInt a, const APInt &b) { return std::move(a) & b; });
995}
996
997//===----------------------------------------------------------------------===//
998// OrIOp
999//===----------------------------------------------------------------------===//
1000
1001OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
1002 if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
1003 /// or(x, 0) -> x
1004 if (rhsVal.isZero())
1005 return getLhs();
1006 /// or(x, <all ones>) -> <all ones>
1007 if (rhsVal.isAllOnes())
1008 return adaptor.getRhs();
1009 }
1010
1011 APInt intValue;
1012 /// or(x, xor(x, 1)) -> 1
1013 if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
1014 m_ConstantInt(&intValue))) &&
1015 intValue.isAllOnes())
1016 return getRhs().getDefiningOp<XOrIOp>().getRhs();
1017 /// or(xor(x, 1), x) -> 1
1018 if (matchPattern(getLhs(), m_Op<XOrIOp>(matchers::m_Val(getRhs()),
1019 m_ConstantInt(&intValue))) &&
1020 intValue.isAllOnes())
1021 return getLhs().getDefiningOp<XOrIOp>().getRhs();
1022
1024 adaptor.getOperands(),
1025 [](APInt a, const APInt &b) { return std::move(a) | b; });
1026}
1027
1028//===----------------------------------------------------------------------===//
1029// XOrIOp
1030//===----------------------------------------------------------------------===//
1031
1032OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
1033 /// xor(x, 0) -> x
1034 if (matchPattern(adaptor.getRhs(), m_Zero()))
1035 return getLhs();
1036 /// xor(x, x) -> 0
1037 if (getLhs() == getRhs())
1038 return Builder(getContext()).getZeroAttr(getType());
1039 /// xor(xor(x, a), a) -> x
1040 /// xor(xor(a, x), a) -> x
1041 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
1042 if (prev.getRhs() == getRhs())
1043 return prev.getLhs();
1044 if (prev.getLhs() == getRhs())
1045 return prev.getRhs();
1046 }
1047 /// xor(a, xor(x, a)) -> x
1048 /// xor(a, xor(a, x)) -> x
1049 if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
1050 if (prev.getRhs() == getLhs())
1051 return prev.getLhs();
1052 if (prev.getLhs() == getLhs())
1053 return prev.getRhs();
1054 }
1055
1057 adaptor.getOperands(),
1058 [](APInt a, const APInt &b) { return std::move(a) ^ b; });
1059}
1060
1061void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1062 MLIRContext *context) {
1063 patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
1064}
1065
1066//===----------------------------------------------------------------------===//
1067// NegFOp
1068//===----------------------------------------------------------------------===//
1069
1070OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
1071 /// negf(negf(x)) -> x
1072 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
1073 return op.getOperand();
1074 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
1075 [](const APFloat &a) { return -a; });
1076}
1077
1078//===----------------------------------------------------------------------===//
1079// AddFOp
1080//===----------------------------------------------------------------------===//
1081
1082OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
1083 // addf(x, -0) -> x
1084 if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
1085 return getLhs();
1086
1088 adaptor.getOperands(),
1089 [](const APFloat &a, const APFloat &b) { return a + b; });
1090}
1091
1092//===----------------------------------------------------------------------===//
1093// SubFOp
1094//===----------------------------------------------------------------------===//
1095
1096OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
1097 // subf(x, +0) -> x
1098 if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
1099 return getLhs();
1100
1102 adaptor.getOperands(),
1103 [](const APFloat &a, const APFloat &b) { return a - b; });
1104}
1105
1106//===----------------------------------------------------------------------===//
1107// MaximumFOp
1108//===----------------------------------------------------------------------===//
1109
1110OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
1111 // maximumf(x,x) -> x
1112 if (getLhs() == getRhs())
1113 return getRhs();
1114
1115 // maximumf(x, -inf) -> x
1116 if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
1117 return getLhs();
1118
1120 adaptor.getOperands(),
1121 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
1122}
1123
1124//===----------------------------------------------------------------------===//
1125// MaxNumFOp
1126//===----------------------------------------------------------------------===//
1127
1128OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
1129 // maxnumf(x,x) -> x
1130 if (getLhs() == getRhs())
1131 return getRhs();
1132
1133 // maxnumf(x, NaN) -> x
1134 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1135 return getLhs();
1136
1137 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), llvm::maxnum);
1138}
1139
1140//===----------------------------------------------------------------------===//
1141// MaxSIOp
1142//===----------------------------------------------------------------------===//
1143
1144OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
1145 // maxsi(x,x) -> x
1146 if (getLhs() == getRhs())
1147 return getRhs();
1148
1149 if (APInt intValue;
1150 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1151 // maxsi(x,MAX_INT) -> MAX_INT
1152 if (intValue.isMaxSignedValue())
1153 return getRhs();
1154 // maxsi(x, MIN_INT) -> x
1155 if (intValue.isMinSignedValue())
1156 return getLhs();
1157 }
1158
1159 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1160 [](const APInt &a, const APInt &b) {
1161 return llvm::APIntOps::smax(a, b);
1162 });
1163}
1164
1165//===----------------------------------------------------------------------===//
1166// MaxUIOp
1167//===----------------------------------------------------------------------===//
1168
1169OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
1170 // maxui(x,x) -> x
1171 if (getLhs() == getRhs())
1172 return getRhs();
1173
1174 if (APInt intValue;
1175 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1176 // maxui(x,MAX_INT) -> MAX_INT
1177 if (intValue.isMaxValue())
1178 return getRhs();
1179 // maxui(x, MIN_INT) -> x
1180 if (intValue.isMinValue())
1181 return getLhs();
1182 }
1183
1184 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1185 [](const APInt &a, const APInt &b) {
1186 return llvm::APIntOps::umax(a, b);
1187 });
1188}
1189
1190//===----------------------------------------------------------------------===//
1191// MinimumFOp
1192//===----------------------------------------------------------------------===//
1193
1194OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
1195 // minimumf(x,x) -> x
1196 if (getLhs() == getRhs())
1197 return getRhs();
1198
1199 // minimumf(x, +inf) -> x
1200 if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
1201 return getLhs();
1202
1204 adaptor.getOperands(),
1205 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
1206}
1207
1208//===----------------------------------------------------------------------===//
1209// MinNumFOp
1210//===----------------------------------------------------------------------===//
1211
1212OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
1213 // minnumf(x,x) -> x
1214 if (getLhs() == getRhs())
1215 return getRhs();
1216
1217 // minnumf(x, NaN) -> x
1218 if (matchPattern(adaptor.getRhs(), m_NaNFloat()))
1219 return getLhs();
1220
1222 adaptor.getOperands(),
1223 [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); });
1224}
1225
1226//===----------------------------------------------------------------------===//
1227// MinSIOp
1228//===----------------------------------------------------------------------===//
1229
1230OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
1231 // minsi(x,x) -> x
1232 if (getLhs() == getRhs())
1233 return getRhs();
1234
1235 if (APInt intValue;
1236 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1237 // minsi(x,MIN_INT) -> MIN_INT
1238 if (intValue.isMinSignedValue())
1239 return getRhs();
1240 // minsi(x, MAX_INT) -> x
1241 if (intValue.isMaxSignedValue())
1242 return getLhs();
1243 }
1244
1245 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1246 [](const APInt &a, const APInt &b) {
1247 return llvm::APIntOps::smin(a, b);
1248 });
1249}
1250
1251//===----------------------------------------------------------------------===//
1252// MinUIOp
1253//===----------------------------------------------------------------------===//
1254
1255OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
1256 // minui(x,x) -> x
1257 if (getLhs() == getRhs())
1258 return getRhs();
1259
1260 if (APInt intValue;
1261 matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
1262 // minui(x,MIN_INT) -> MIN_INT
1263 if (intValue.isMinValue())
1264 return getRhs();
1265 // minui(x, MAX_INT) -> x
1266 if (intValue.isMaxValue())
1267 return getLhs();
1268 }
1269
1270 return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
1271 [](const APInt &a, const APInt &b) {
1272 return llvm::APIntOps::umin(a, b);
1273 });
1274}
1275
1276//===----------------------------------------------------------------------===//
1277// MulFOp
1278//===----------------------------------------------------------------------===//
1279
1280OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
1281 // mulf(x, 1) -> x
1282 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1283 return getLhs();
1284
1285 if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
1286 arith::FastMathFlags::nsz)) {
1287 // mulf(x, 0) -> 0
1288 if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
1289 return getRhs();
1290 }
1291
1293 adaptor.getOperands(),
1294 [](const APFloat &a, const APFloat &b) { return a * b; });
1295}
1296
1297void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1298 MLIRContext *context) {
1299 patterns.add<MulFOfNegF>(context);
1300}
1301
1302//===----------------------------------------------------------------------===//
1303// DivFOp
1304//===----------------------------------------------------------------------===//
1305
1306OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
1307 // divf(x, 1) -> x
1308 if (matchPattern(adaptor.getRhs(), m_OneFloat()))
1309 return getLhs();
1310
1312 adaptor.getOperands(),
1313 [](const APFloat &a, const APFloat &b) { return a / b; });
1314}
1315
1316void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1317 MLIRContext *context) {
1318 patterns.add<DivFOfNegF>(context);
1319}
1320
1321//===----------------------------------------------------------------------===//
1322// RemFOp
1323//===----------------------------------------------------------------------===//
1324
1325OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
1326 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
1327 [](const APFloat &a, const APFloat &b) {
1328 APFloat result(a);
1329 // APFloat::mod() offers the remainder
1330 // behavior we want, i.e. the result has
1331 // the sign of LHS operand.
1332 (void)result.mod(b);
1333 return result;
1334 });
1335}
1336
1337//===----------------------------------------------------------------------===//
1338// Utility functions for verifying cast ops
1339//===----------------------------------------------------------------------===//
1340
/// Pointer-to-tuple alias used only as a tag that carries a pack of types; a
/// value of it (e.g. a null pointer) can be passed without building a tuple.
/// Presumably it lets the two-pack helper below (`ShapedTypes...`,
/// `ElementTypes...`) receive both packs through function arguments.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
1343
1344/// Returns a non-null type only if the provided type is one of the allowed
1345/// types or one of the allowed shaped types of the allowed types. Returns the
1346/// element type if a valid shaped type is provided.
1347template <typename... ShapedTypes, typename... ElementTypes>
1350 if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
1351 return {};
1352
1353 auto underlyingType = getElementTypeOrSelf(type);
1354 if (!llvm::isa<ElementTypes...>(underlyingType))
1355 return {};
1356
1357 return underlyingType;
1358}
1359
1360/// Get allowed underlying types for vectors and tensors.
1361template <typename... ElementTypes>
1366
1367/// Get allowed underlying types for vectors, tensors, and memrefs.
1368template <typename... ElementTypes>
1374
1375/// Return false if both types are ranked tensor with mismatching encoding.
1376static bool hasSameEncoding(Type typeA, Type typeB) {
1377 auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
1378 auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
1379 if (!rankedTensorA || !rankedTensorB)
1380 return true;
1381 return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
1382}
1383
1385 if (inputs.size() != 1 || outputs.size() != 1)
1386 return false;
1387 if (!hasSameEncoding(inputs.front(), outputs.front()))
1388 return false;
1389 return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
1390}
1391
1392//===----------------------------------------------------------------------===//
1393// Verifiers for integer and floating point extension/truncation ops
1394//===----------------------------------------------------------------------===//
1395
1396// Extend ops can only extend to a wider type.
1397template <typename ValType, typename Op>
1398static LogicalResult verifyExtOp(Op op) {
1399 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1400 Type dstType = getElementTypeOrSelf(op.getType());
1401
1402 if (llvm::cast<ValType>(srcType).getWidth() >=
1403 llvm::cast<ValType>(dstType).getWidth())
1404 return op.emitError("result type ")
1405 << dstType << " must be wider than operand type " << srcType;
1406
1407 return success();
1408}
1409
1410// Truncate ops can only truncate to a shorter type.
1411template <typename ValType, typename Op>
1412static LogicalResult verifyTruncateOp(Op op) {
1413 Type srcType = getElementTypeOrSelf(op.getIn().getType());
1414 Type dstType = getElementTypeOrSelf(op.getType());
1415
1416 if (llvm::cast<ValType>(srcType).getWidth() <=
1417 llvm::cast<ValType>(dstType).getWidth())
1418 return op.emitError("result type ")
1419 << dstType << " must be shorter than operand type " << srcType;
1420
1421 return success();
1422}
1423
1424/// Validate a cast that changes the width of a type.
1425template <template <typename> class WidthComparator, typename... ElementTypes>
1426static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
1427 if (!areValidCastInputsAndOutputs(inputs, outputs))
1428 return false;
1429
1430 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
1431 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
1432 if (!srcType || !dstType)
1433 return false;
1434
1435 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
1436 srcType.getIntOrFloatBitWidth());
1437}
1438
1439/// Attempts to convert `sourceValue` to an APFloat value with
1440/// `targetSemantics` and `roundingMode`, without any information loss.
1441static FailureOr<APFloat> convertFloatValue(
1442 APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
1443 llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
1444 bool losesInfo = false;
1445 auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
1446 if (losesInfo || status != APFloat::opOK)
1447 return failure();
1448
1449 return sourceValue;
1450}
1451
1452//===----------------------------------------------------------------------===//
1453// ExtUIOp
1454//===----------------------------------------------------------------------===//
1455
1456OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
1457 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
1458 getInMutable().assign(lhs.getIn());
1459 return getResult();
1460 }
1461
1462 Type resType = getElementTypeOrSelf(getType());
1463 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1465 adaptor.getOperands(), getType(),
1466 [bitWidth](const APInt &a, bool &castStatus) {
1467 return a.zext(bitWidth);
1468 });
1469}
1470
1471bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1473}
1474
1475LogicalResult arith::ExtUIOp::verify() {
1476 return verifyExtOp<IntegerType>(*this);
1477}
1478
1479//===----------------------------------------------------------------------===//
1480// ExtSIOp
1481//===----------------------------------------------------------------------===//
1482
1483OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
1484 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
1485 getInMutable().assign(lhs.getIn());
1486 return getResult();
1487 }
1488
1489 Type resType = getElementTypeOrSelf(getType());
1490 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1492 adaptor.getOperands(), getType(),
1493 [bitWidth](const APInt &a, bool &castStatus) {
1494 return a.sext(bitWidth);
1495 });
1496}
1497
1498bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1500}
1501
1502void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1503 MLIRContext *context) {
1504 patterns.add<ExtSIOfExtUI>(context);
1505}
1506
1507LogicalResult arith::ExtSIOp::verify() {
1508 return verifyExtOp<IntegerType>(*this);
1509}
1510
1511//===----------------------------------------------------------------------===//
1512// ExtFOp
1513//===----------------------------------------------------------------------===//
1514
1515/// Fold extension of float constants when there is no information loss due the
1516/// difference in fp semantics.
1517OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1518 if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1519 if (truncFOp.getOperand().getType() == getType()) {
1520 arith::FastMathFlags truncFMF =
1521 truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1522 bool isTruncContract =
1523 bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1524 arith::FastMathFlags extFMF =
1525 getFastmath().value_or(arith::FastMathFlags::none);
1526 bool isExtContract =
1527 bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1528 if (isTruncContract && isExtContract) {
1529 return truncFOp.getOperand();
1530 }
1531 }
1532 }
1533
1534 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1535 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1537 adaptor.getOperands(), getType(),
1538 [&targetSemantics](const APFloat &a, bool &castStatus) {
1539 FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
1540 if (failed(result)) {
1541 castStatus = false;
1542 return a;
1543 }
1544 return *result;
1545 });
1546}
1547
1548bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1549 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
1550}
1551
1552LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
1553
1554//===----------------------------------------------------------------------===//
1555// ScalingExtFOp
1556//===----------------------------------------------------------------------===//
1557
1558bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
1559 TypeRange outputs) {
1560 return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
1561}
1562
1563LogicalResult arith::ScalingExtFOp::verify() {
1564 return verifyExtOp<FloatType>(*this);
1565}
1566
1567//===----------------------------------------------------------------------===//
1568// TruncIOp
1569//===----------------------------------------------------------------------===//
1570
1571OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
1572 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
1573 matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
1574 Value src = getOperand().getDefiningOp()->getOperand(0);
1575 Type srcType = getElementTypeOrSelf(src.getType());
1576 Type dstType = getElementTypeOrSelf(getType());
1577 // trunci(zexti(a)) -> trunci(a)
1578 // trunci(sexti(a)) -> trunci(a)
1579 if (llvm::cast<IntegerType>(srcType).getWidth() >
1580 llvm::cast<IntegerType>(dstType).getWidth()) {
1581 setOperand(src);
1582 return getResult();
1583 }
1584
1585 // trunci(zexti(a)) -> a
1586 // trunci(sexti(a)) -> a
1587 if (srcType == dstType)
1588 return src;
1589 }
1590
1591 // trunci(trunci(a)) -> trunci(a))
1592 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
1593 setOperand(getOperand().getDefiningOp()->getOperand(0));
1594 return getResult();
1595 }
1596
1597 Type resType = getElementTypeOrSelf(getType());
1598 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1600 adaptor.getOperands(), getType(),
1601 [bitWidth](const APInt &a, bool &castStatus) {
1602 return a.trunc(bitWidth);
1603 });
1604}
1605
1606bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1607 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
1608}
1609
1610void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1611 MLIRContext *context) {
1612 patterns
1613 .add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
1614 context);
1615}
1616
1617LogicalResult arith::TruncIOp::verify() {
1618 return verifyTruncateOp<IntegerType>(*this);
1619}
1620
1621//===----------------------------------------------------------------------===//
1622// TruncFOp
1623//===----------------------------------------------------------------------===//
1624
1625/// Perform safe const propagation for truncf, i.e., only propagate if FP value
1626/// can be represented without precision loss.
1627OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1628 auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1629 if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1630 Value src = extOp.getIn();
1631 auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1632 auto intermediateType =
1633 cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1634 // Check if the srcType is representable in the intermediateType.
1635 if (llvm::APFloatBase::isRepresentableBy(
1636 srcType.getFloatSemantics(),
1637 intermediateType.getFloatSemantics())) {
1638 // truncf(extf(a)) -> truncf(a)
1639 if (srcType.getWidth() > resElemType.getWidth()) {
1640 setOperand(src);
1641 return getResult();
1642 }
1643
1644 // truncf(extf(a)) -> a
1645 if (srcType == resElemType)
1646 return src;
1647 }
1648 }
1649
1650 const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
1652 adaptor.getOperands(), getType(),
1653 [this, &targetSemantics](const APFloat &a, bool &castStatus) {
1654 RoundingMode roundingMode =
1655 getRoundingmode().value_or(RoundingMode::to_nearest_even);
1656 llvm::RoundingMode llvmRoundingMode =
1658 FailureOr<APFloat> result =
1659 convertFloatValue(a, targetSemantics, llvmRoundingMode);
1660 if (failed(result)) {
1661 castStatus = false;
1662 return a;
1663 }
1664 return *result;
1665 });
1666}
1667
1668void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1669 MLIRContext *context) {
1670 patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1671}
1672
1673bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1674 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1675}
1676
1677LogicalResult arith::TruncFOp::verify() {
1678 return verifyTruncateOp<FloatType>(*this);
1679}
1680
1681//===----------------------------------------------------------------------===//
1682// ScalingTruncFOp
1683//===----------------------------------------------------------------------===//
1684
1685bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
1686 TypeRange outputs) {
1687 return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
1688}
1689
1690LogicalResult arith::ScalingTruncFOp::verify() {
1691 return verifyTruncateOp<FloatType>(*this);
1692}
1693
1694//===----------------------------------------------------------------------===//
1695// AndIOp
1696//===----------------------------------------------------------------------===//
1697
1698void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1699 MLIRContext *context) {
1700 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1701}
1702
1703//===----------------------------------------------------------------------===//
1704// OrIOp
1705//===----------------------------------------------------------------------===//
1706
1707void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1708 MLIRContext *context) {
1709 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1710}
1711
1712//===----------------------------------------------------------------------===//
1713// Verifiers for casts between integers and floats.
1714//===----------------------------------------------------------------------===//
1715
1716template <typename From, typename To>
1717static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1718 if (!areValidCastInputsAndOutputs(inputs, outputs))
1719 return false;
1720
1721 auto srcType = getTypeIfLike<From>(inputs.front());
1722 auto dstType = getTypeIfLike<To>(outputs.back());
1723
1724 return srcType && dstType;
1725}
1726
1727//===----------------------------------------------------------------------===//
1728// UIToFPOp
1729//===----------------------------------------------------------------------===//
1730
1731bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1732 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1733}
1734
1735OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
1736 Type resEleType = getElementTypeOrSelf(getType());
1738 adaptor.getOperands(), getType(),
1739 [&resEleType](const APInt &a, bool &castStatus) {
1740 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1741 APFloat apf(floatTy.getFloatSemantics(),
1742 APInt::getZero(floatTy.getWidth()));
1743 apf.convertFromAPInt(a, /*IsSigned=*/false,
1744 APFloat::rmNearestTiesToEven);
1745 return apf;
1746 });
1747}
1748
1749//===----------------------------------------------------------------------===//
1750// SIToFPOp
1751//===----------------------------------------------------------------------===//
1752
1753bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1754 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1755}
1756
1757OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
1758 Type resEleType = getElementTypeOrSelf(getType());
1760 adaptor.getOperands(), getType(),
1761 [&resEleType](const APInt &a, bool &castStatus) {
1762 FloatType floatTy = llvm::cast<FloatType>(resEleType);
1763 APFloat apf(floatTy.getFloatSemantics(),
1764 APInt::getZero(floatTy.getWidth()));
1765 apf.convertFromAPInt(a, /*IsSigned=*/true,
1766 APFloat::rmNearestTiesToEven);
1767 return apf;
1768 });
1769}
1770
1771//===----------------------------------------------------------------------===//
1772// FPToUIOp
1773//===----------------------------------------------------------------------===//
1774
1775bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1776 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1777}
1778
1779OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
1780 Type resType = getElementTypeOrSelf(getType());
1781 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1783 adaptor.getOperands(), getType(),
1784 [&bitWidth](const APFloat &a, bool &castStatus) {
1785 bool ignored;
1786 APSInt api(bitWidth, /*isUnsigned=*/true);
1787 castStatus = APFloat::opInvalidOp !=
1788 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1789 return api;
1790 });
1791}
1792
1793//===----------------------------------------------------------------------===//
1794// FPToSIOp
1795//===----------------------------------------------------------------------===//
1796
1797bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1798 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1799}
1800
1801OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
1802 Type resType = getElementTypeOrSelf(getType());
1803 unsigned bitWidth = llvm::cast<IntegerType>(resType).getWidth();
1805 adaptor.getOperands(), getType(),
1806 [&bitWidth](const APFloat &a, bool &castStatus) {
1807 bool ignored;
1808 APSInt api(bitWidth, /*isUnsigned=*/false);
1809 castStatus = APFloat::opInvalidOp !=
1810 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1811 return api;
1812 });
1813}
1814
1815//===----------------------------------------------------------------------===//
1816// IndexCastOp
1817//===----------------------------------------------------------------------===//
1818
1819static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
1820 if (!areValidCastInputsAndOutputs(inputs, outputs))
1821 return false;
1822
1823 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1824 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1825 if (!srcType || !dstType)
1826 return false;
1827
1828 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1829 (srcType.isSignlessInteger() && dstType.isIndex());
1830}
1831
1832bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1833 TypeRange outputs) {
1834 return areIndexCastCompatible(inputs, outputs);
1835}
1836
1837OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
1838 // index_cast(constant) -> constant
1839 unsigned resultBitwidth = 64; // Default for index integer attributes.
1840 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1841 resultBitwidth = intTy.getWidth();
1842
1844 adaptor.getOperands(), getType(),
1845 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1846 return a.sextOrTrunc(resultBitwidth);
1847 });
1848}
1849
1850void arith::IndexCastOp::getCanonicalizationPatterns(
1851 RewritePatternSet &patterns, MLIRContext *context) {
1852 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1853}
1854
1855//===----------------------------------------------------------------------===//
1856// IndexCastUIOp
1857//===----------------------------------------------------------------------===//
1858
1859bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
1860 TypeRange outputs) {
1861 return areIndexCastCompatible(inputs, outputs);
1862}
1863
1864OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
1865 // index_castui(constant) -> constant
1866 unsigned resultBitwidth = 64; // Default for index integer attributes.
1867 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType())))
1868 resultBitwidth = intTy.getWidth();
1869
1871 adaptor.getOperands(), getType(),
1872 [resultBitwidth](const APInt &a, bool & /*castStatus*/) {
1873 return a.zextOrTrunc(resultBitwidth);
1874 });
1875}
1876
1877void arith::IndexCastUIOp::getCanonicalizationPatterns(
1878 RewritePatternSet &patterns, MLIRContext *context) {
1879 patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
1880}
1881
1882//===----------------------------------------------------------------------===//
1883// BitcastOp
1884//===----------------------------------------------------------------------===//
1885
1886bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1887 if (!areValidCastInputsAndOutputs(inputs, outputs))
1888 return false;
1889
1890 auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1891 auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
1892 if (!srcType || !dstType)
1893 return false;
1894
1895 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1896}
1897
1898OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
1899 auto resType = getType();
1900 auto operand = adaptor.getIn();
1901 if (!operand)
1902 return {};
1903
1904 /// Bitcast dense elements.
1905 if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(operand))
1906 return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
1907 /// Other shaped types unhandled.
1908 if (llvm::isa<ShapedType>(resType))
1909 return {};
1910
1911 /// Bitcast poison.
1912 if (llvm::isa<ub::PoisonAttr>(operand))
1913 return ub::PoisonAttr::get(getContext());
1914
1915 /// Bitcast integer or float to integer or float.
1916 APInt bits = llvm::isa<FloatAttr>(operand)
1917 ? llvm::cast<FloatAttr>(operand).getValue().bitcastToAPInt()
1918 : llvm::cast<IntegerAttr>(operand).getValue();
1919 assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() &&
1920 "trying to fold on broken IR: operands have incompatible types");
1921
1922 if (auto resFloatType = dyn_cast<FloatType>(resType))
1923 return FloatAttr::get(resType,
1924 APFloat(resFloatType.getFloatSemantics(), bits));
1925 return IntegerAttr::get(resType, bits);
1926}
1927
1928void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1929 MLIRContext *context) {
1930 patterns.add<BitcastOfBitcast>(context);
1931}
1932
1933//===----------------------------------------------------------------------===//
1934// CmpIOp
1935//===----------------------------------------------------------------------===//
1936
1937/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1938/// comparison predicates.
1939bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1940 const APInt &lhs, const APInt &rhs) {
1941 switch (predicate) {
1942 case arith::CmpIPredicate::eq:
1943 return lhs.eq(rhs);
1944 case arith::CmpIPredicate::ne:
1945 return lhs.ne(rhs);
1946 case arith::CmpIPredicate::slt:
1947 return lhs.slt(rhs);
1948 case arith::CmpIPredicate::sle:
1949 return lhs.sle(rhs);
1950 case arith::CmpIPredicate::sgt:
1951 return lhs.sgt(rhs);
1952 case arith::CmpIPredicate::sge:
1953 return lhs.sge(rhs);
1954 case arith::CmpIPredicate::ult:
1955 return lhs.ult(rhs);
1956 case arith::CmpIPredicate::ule:
1957 return lhs.ule(rhs);
1958 case arith::CmpIPredicate::ugt:
1959 return lhs.ugt(rhs);
1960 case arith::CmpIPredicate::uge:
1961 return lhs.uge(rhs);
1962 }
1963 llvm_unreachable("unknown cmpi predicate kind");
1964}
1965
1966/// Returns true if the predicate is true for two equal operands.
1967static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1968 switch (predicate) {
1969 case arith::CmpIPredicate::eq:
1970 case arith::CmpIPredicate::sle:
1971 case arith::CmpIPredicate::sge:
1972 case arith::CmpIPredicate::ule:
1973 case arith::CmpIPredicate::uge:
1974 return true;
1975 case arith::CmpIPredicate::ne:
1976 case arith::CmpIPredicate::slt:
1977 case arith::CmpIPredicate::sgt:
1978 case arith::CmpIPredicate::ult:
1979 case arith::CmpIPredicate::ugt:
1980 return false;
1981 }
1982 llvm_unreachable("unknown cmpi predicate kind");
1983}
1984
1985static std::optional<int64_t> getIntegerWidth(Type t) {
1986 if (auto intType = dyn_cast<IntegerType>(t)) {
1987 return intType.getWidth();
1988 }
1989 if (auto vectorIntType = dyn_cast<VectorType>(t)) {
1990 return llvm::cast<IntegerType>(vectorIntType.getElementType()).getWidth();
1991 }
1992 return std::nullopt;
1993}
1994
1995OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
1996 // cmpi(pred, x, x)
1997 if (getLhs() == getRhs()) {
1998 auto val = applyCmpPredicateToEqualOperands(getPredicate());
1999 return getBoolAttribute(getType(), val);
2000 }
2001
2002 if (matchPattern(adaptor.getRhs(), m_Zero())) {
2003 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
2004 // extsi(%x : i1 -> iN) != 0 -> %x
2005 std::optional<int64_t> integerWidth =
2006 getIntegerWidth(extOp.getOperand().getType());
2007 if (integerWidth && integerWidth.value() == 1 &&
2008 getPredicate() == arith::CmpIPredicate::ne)
2009 return extOp.getOperand();
2010 }
2011 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
2012 // extui(%x : i1 -> iN) != 0 -> %x
2013 std::optional<int64_t> integerWidth =
2014 getIntegerWidth(extOp.getOperand().getType());
2015 if (integerWidth && integerWidth.value() == 1 &&
2016 getPredicate() == arith::CmpIPredicate::ne)
2017 return extOp.getOperand();
2018 }
2019
2020 // arith.cmpi ne, %val, %zero : i1 -> %val
2021 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2022 getPredicate() == arith::CmpIPredicate::ne)
2023 return getLhs();
2024 }
2025
2026 if (matchPattern(adaptor.getRhs(), m_One())) {
2027 // arith.cmpi eq, %val, %one : i1 -> %val
2028 if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
2029 getPredicate() == arith::CmpIPredicate::eq)
2030 return getLhs();
2031 }
2032
2033 // Move constant to the right side.
2034 if (adaptor.getLhs() && !adaptor.getRhs()) {
2035 // Do not use invertPredicate, as it will change eq to ne and vice versa.
2036 using Pred = CmpIPredicate;
2037 const std::pair<Pred, Pred> invPreds[] = {
2038 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
2039 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
2040 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
2041 {Pred::ne, Pred::ne},
2042 };
2043 Pred origPred = getPredicate();
2044 for (auto pred : invPreds) {
2045 if (origPred == pred.first) {
2046 setPredicate(pred.second);
2047 Value lhs = getLhs();
2048 Value rhs = getRhs();
2049 getLhsMutable().assign(rhs);
2050 getRhsMutable().assign(lhs);
2051 return getResult();
2052 }
2053 }
2054 llvm_unreachable("unknown cmpi predicate kind");
2055 }
2056
2057 // We are moving constants to the right side; So if lhs is constant rhs is
2058 // guaranteed to be a constant.
2059 if (auto lhs = dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
2061 adaptor.getOperands(), getI1SameShape(lhs.getType()),
2062 [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
2063 return APInt(1,
2064 static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
2065 });
2066 }
2067
2068 return {};
2069}
2070
2071void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2072 MLIRContext *context) {
2073 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
2074}
2075
2076//===----------------------------------------------------------------------===//
2077// CmpFOp
2078//===----------------------------------------------------------------------===//
2079
2080/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
2081/// comparison predicates.
2082bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
2083 const APFloat &lhs, const APFloat &rhs) {
2084 auto cmpResult = lhs.compare(rhs);
2085 switch (predicate) {
2086 case arith::CmpFPredicate::AlwaysFalse:
2087 return false;
2088 case arith::CmpFPredicate::OEQ:
2089 return cmpResult == APFloat::cmpEqual;
2090 case arith::CmpFPredicate::OGT:
2091 return cmpResult == APFloat::cmpGreaterThan;
2092 case arith::CmpFPredicate::OGE:
2093 return cmpResult == APFloat::cmpGreaterThan ||
2094 cmpResult == APFloat::cmpEqual;
2095 case arith::CmpFPredicate::OLT:
2096 return cmpResult == APFloat::cmpLessThan;
2097 case arith::CmpFPredicate::OLE:
2098 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2099 case arith::CmpFPredicate::ONE:
2100 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
2101 case arith::CmpFPredicate::ORD:
2102 return cmpResult != APFloat::cmpUnordered;
2103 case arith::CmpFPredicate::UEQ:
2104 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
2105 case arith::CmpFPredicate::UGT:
2106 return cmpResult == APFloat::cmpUnordered ||
2107 cmpResult == APFloat::cmpGreaterThan;
2108 case arith::CmpFPredicate::UGE:
2109 return cmpResult == APFloat::cmpUnordered ||
2110 cmpResult == APFloat::cmpGreaterThan ||
2111 cmpResult == APFloat::cmpEqual;
2112 case arith::CmpFPredicate::ULT:
2113 return cmpResult == APFloat::cmpUnordered ||
2114 cmpResult == APFloat::cmpLessThan;
2115 case arith::CmpFPredicate::ULE:
2116 return cmpResult == APFloat::cmpUnordered ||
2117 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
2118 case arith::CmpFPredicate::UNE:
2119 return cmpResult != APFloat::cmpEqual;
2120 case arith::CmpFPredicate::UNO:
2121 return cmpResult == APFloat::cmpUnordered;
2122 case arith::CmpFPredicate::AlwaysTrue:
2123 return true;
2124 }
2125 llvm_unreachable("unknown cmpf predicate kind");
2126}
2127
2128OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
2129 auto lhs = dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
2130 auto rhs = dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
2131
2132 // If one operand is NaN, making them both NaN does not change the result.
2133 if (lhs && lhs.getValue().isNaN())
2134 rhs = lhs;
2135 if (rhs && rhs.getValue().isNaN())
2136 lhs = rhs;
2137
2138 if (!lhs || !rhs)
2139 return {};
2140
2141 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2142 return BoolAttr::get(getContext(), val);
2143}
2144
2145class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
2146public:
2147 using Base::Base;
2148
2149 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
2150 bool isUnsigned) {
2151 using namespace arith;
2152 switch (pred) {
2153 case CmpFPredicate::UEQ:
2154 case CmpFPredicate::OEQ:
2155 return CmpIPredicate::eq;
2156 case CmpFPredicate::UGT:
2157 case CmpFPredicate::OGT:
2158 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
2159 case CmpFPredicate::UGE:
2160 case CmpFPredicate::OGE:
2161 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
2162 case CmpFPredicate::ULT:
2163 case CmpFPredicate::OLT:
2164 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
2165 case CmpFPredicate::ULE:
2166 case CmpFPredicate::OLE:
2167 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
2168 case CmpFPredicate::UNE:
2169 case CmpFPredicate::ONE:
2170 return CmpIPredicate::ne;
2171 default:
2172 llvm_unreachable("Unexpected predicate!");
2173 }
2174 }
2175
2176 LogicalResult matchAndRewrite(CmpFOp op,
2177 PatternRewriter &rewriter) const override {
2178 FloatAttr flt;
2179 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
2180 return failure();
2181
2182 const APFloat &rhs = flt.getValue();
2183
2184 // Don't attempt to fold a nan.
2185 if (rhs.isNaN())
2186 return failure();
2187
2188 // Get the width of the mantissa. We don't want to hack on conversions that
2189 // might lose information from the integer, e.g. "i64 -> float"
2190 FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
2191 int mantissaWidth = floatTy.getFPMantissaWidth();
2192 if (mantissaWidth <= 0)
2193 return failure();
2194
2195 bool isUnsigned;
2196 Value intVal;
2197
2198 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
2199 isUnsigned = false;
2200 intVal = si.getIn();
2201 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
2202 isUnsigned = true;
2203 intVal = ui.getIn();
2204 } else {
2205 return failure();
2206 }
2207
2208 // Check to see that the input is converted from an integer type that is
2209 // small enough that preserves all bits.
2210 auto intTy = llvm::cast<IntegerType>(intVal.getType());
2211 auto intWidth = intTy.getWidth();
2212
2213 // Number of bits representing values, as opposed to the sign
2214 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
2215
2216 // Following test does NOT adjust intWidth downwards for signed inputs,
2217 // because the most negative value still requires all the mantissa bits
2218 // to distinguish it from one less than that value.
2219 if ((int)intWidth > mantissaWidth) {
2220 // Conversion would lose accuracy. Check if loss can impact comparison.
2221 int exponent = ilogb(rhs);
2222 if (exponent == APFloat::IEK_Inf) {
2223 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
2224 if (maxExponent < (int)valueBits) {
2225 // Conversion could create infinity.
2226 return failure();
2227 }
2228 } else {
2229 // Note that if rhs is zero or NaN, then Exp is negative
2230 // and first condition is trivially false.
2231 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
2232 // Conversion could affect comparison.
2233 return failure();
2234 }
2235 }
2236 }
2237
2238 // Convert to equivalent cmpi predicate
2239 CmpIPredicate pred;
2240 switch (op.getPredicate()) {
2241 case CmpFPredicate::ORD:
2242 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
2243 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2244 /*width=*/1);
2245 return success();
2246 case CmpFPredicate::UNO:
2247 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
2248 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2249 /*width=*/1);
2250 return success();
2251 default:
2252 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
2253 break;
2254 }
2255
2256 if (!isUnsigned) {
2257 // If the rhs value is > SignedMax, fold the comparison. This handles
2258 // +INF and large values.
2259 APFloat signedMax(rhs.getSemantics());
2260 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
2261 APFloat::rmNearestTiesToEven);
2262 if (signedMax < rhs) { // smax < 13123.0
2263 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
2264 pred == CmpIPredicate::sle)
2265 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2266 /*width=*/1);
2267 else
2268 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2269 /*width=*/1);
2270 return success();
2271 }
2272 } else {
2273 // If the rhs value is > UnsignedMax, fold the comparison. This handles
2274 // +INF and large values.
2275 APFloat unsignedMax(rhs.getSemantics());
2276 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
2277 APFloat::rmNearestTiesToEven);
2278 if (unsignedMax < rhs) { // umax < 13123.0
2279 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
2280 pred == CmpIPredicate::ule)
2281 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2282 /*width=*/1);
2283 else
2284 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2285 /*width=*/1);
2286 return success();
2287 }
2288 }
2289
2290 if (!isUnsigned) {
2291 // See if the rhs value is < SignedMin.
2292 APFloat signedMin(rhs.getSemantics());
2293 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
2294 APFloat::rmNearestTiesToEven);
2295 if (signedMin > rhs) { // smin > 12312.0
2296 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
2297 pred == CmpIPredicate::sge)
2298 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2299 /*width=*/1);
2300 else
2301 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2302 /*width=*/1);
2303 return success();
2304 }
2305 } else {
2306 // See if the rhs value is < UnsignedMin.
2307 APFloat unsignedMin(rhs.getSemantics());
2308 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
2309 APFloat::rmNearestTiesToEven);
2310 if (unsignedMin > rhs) { // umin > 12312.0
2311 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
2312 pred == CmpIPredicate::uge)
2313 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2314 /*width=*/1);
2315 else
2316 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2317 /*width=*/1);
2318 return success();
2319 }
2320 }
2321
2322 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
2323 // [0, UMAX], but it may still be fractional. See if it is fractional by
2324 // casting the FP value to the integer value and back, checking for
2325 // equality. Don't do this for zero, because -0.0 is not fractional.
2326 bool ignored;
2327 APSInt rhsInt(intWidth, isUnsigned);
2328 if (APFloat::opInvalidOp ==
2329 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
2330 // Undefined behavior invoked - the destination type can't represent
2331 // the input constant.
2332 return failure();
2333 }
2334
2335 if (!rhs.isZero()) {
2336 APFloat apf(floatTy.getFloatSemantics(),
2337 APInt::getZero(floatTy.getWidth()));
2338 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
2339
2340 bool equal = apf == rhs;
2341 if (!equal) {
2342 // If we had a comparison against a fractional value, we have to adjust
2343 // the compare predicate and sometimes the value. rhsInt is rounded
2344 // towards zero at this point.
2345 switch (pred) {
2346 case CmpIPredicate::ne: // (float)int != 4.4 --> true
2347 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2348 /*width=*/1);
2349 return success();
2350 case CmpIPredicate::eq: // (float)int == 4.4 --> false
2351 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2352 /*width=*/1);
2353 return success();
2354 case CmpIPredicate::ule:
2355 // (float)int <= 4.4 --> int <= 4
2356 // (float)int <= -4.4 --> false
2357 if (rhs.isNegative()) {
2358 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2359 /*width=*/1);
2360 return success();
2361 }
2362 break;
2363 case CmpIPredicate::sle:
2364 // (float)int <= 4.4 --> int <= 4
2365 // (float)int <= -4.4 --> int < -4
2366 if (rhs.isNegative())
2367 pred = CmpIPredicate::slt;
2368 break;
2369 case CmpIPredicate::ult:
2370 // (float)int < -4.4 --> false
2371 // (float)int < 4.4 --> int <= 4
2372 if (rhs.isNegative()) {
2373 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
2374 /*width=*/1);
2375 return success();
2376 }
2377 pred = CmpIPredicate::ule;
2378 break;
2379 case CmpIPredicate::slt:
2380 // (float)int < -4.4 --> int < -4
2381 // (float)int < 4.4 --> int <= 4
2382 if (!rhs.isNegative())
2383 pred = CmpIPredicate::sle;
2384 break;
2385 case CmpIPredicate::ugt:
2386 // (float)int > 4.4 --> int > 4
2387 // (float)int > -4.4 --> true
2388 if (rhs.isNegative()) {
2389 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2390 /*width=*/1);
2391 return success();
2392 }
2393 break;
2394 case CmpIPredicate::sgt:
2395 // (float)int > 4.4 --> int > 4
2396 // (float)int > -4.4 --> int >= -4
2397 if (rhs.isNegative())
2398 pred = CmpIPredicate::sge;
2399 break;
2400 case CmpIPredicate::uge:
2401 // (float)int >= -4.4 --> true
2402 // (float)int >= 4.4 --> int > 4
2403 if (rhs.isNegative()) {
2404 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
2405 /*width=*/1);
2406 return success();
2407 }
2408 pred = CmpIPredicate::ugt;
2409 break;
2410 case CmpIPredicate::sge:
2411 // (float)int >= -4.4 --> int >= -4
2412 // (float)int >= 4.4 --> int > 4
2413 if (!rhs.isNegative())
2414 pred = CmpIPredicate::sgt;
2415 break;
2416 }
2417 }
2418 }
2419
2420 // Lower this FP comparison into an appropriate integer version of the
2421 // comparison.
2422 rewriter.replaceOpWithNewOp<CmpIOp>(
2423 op, pred, intVal,
2424 ConstantOp::create(rewriter, op.getLoc(), intVal.getType(),
2425 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
2426 return success();
2427 }
2428};
2429
2430void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2431 MLIRContext *context) {
2432 patterns.insert<CmpFIntToFPConst>(context);
2433}
2434
2435//===----------------------------------------------------------------------===//
2436// SelectOp
2437//===----------------------------------------------------------------------===//
2438
2439// select %arg, %c1, %c0 => extui %arg
2440struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
2441 using Base::Base;
2442
2443 LogicalResult matchAndRewrite(arith::SelectOp op,
2444 PatternRewriter &rewriter) const override {
2445 // Cannot extui i1 to i1, or i1 to f32
2446 if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
2447 return failure();
2448
2449 // select %x, c1, %c0 => extui %arg
2450 if (matchPattern(op.getTrueValue(), m_One()) &&
2451 matchPattern(op.getFalseValue(), m_Zero())) {
2452 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
2453 op.getCondition());
2454 return success();
2455 }
2456
2457 // select %x, c0, %c1 => extui (xor %arg, true)
2458 if (matchPattern(op.getTrueValue(), m_Zero()) &&
2459 matchPattern(op.getFalseValue(), m_One())) {
2460 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
2461 op, op.getType(),
2462 arith::XOrIOp::create(
2463 rewriter, op.getLoc(), op.getCondition(),
2464 arith::ConstantIntOp::create(rewriter, op.getLoc(),
2465 op.getCondition().getType(), 1)));
2466 return success();
2467 }
2468
2469 return failure();
2470 }
2471};
2472
2473void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2474 MLIRContext *context) {
2475 results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
2476 SelectI1ToNot, SelectToExtUI>(context);
2477}
2478
2479OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
2480 Value trueVal = getTrueValue();
2481 Value falseVal = getFalseValue();
2482 if (trueVal == falseVal)
2483 return trueVal;
2484
2485 Value condition = getCondition();
2486
2487 // select true, %0, %1 => %0
2488 if (matchPattern(adaptor.getCondition(), m_One()))
2489 return trueVal;
2490
2491 // select false, %0, %1 => %1
2492 if (matchPattern(adaptor.getCondition(), m_Zero()))
2493 return falseVal;
2494
2495 // If either operand is fully poisoned, return the other.
2496 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getTrueValue()))
2497 return falseVal;
2498
2499 if (isa_and_nonnull<ub::PoisonAttr>(adaptor.getFalseValue()))
2500 return trueVal;
2501
2502 // select %x, true, false => %x
2503 if (getType().isSignlessInteger(1) &&
2504 matchPattern(adaptor.getTrueValue(), m_One()) &&
2505 matchPattern(adaptor.getFalseValue(), m_Zero()))
2506 return condition;
2507
2508 if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
2509 auto pred = cmp.getPredicate();
2510 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
2511 auto cmpLhs = cmp.getLhs();
2512 auto cmpRhs = cmp.getRhs();
2513
2514 // %0 = arith.cmpi eq, %arg0, %arg1
2515 // %1 = arith.select %0, %arg0, %arg1 => %arg1
2516
2517 // %0 = arith.cmpi ne, %arg0, %arg1
2518 // %1 = arith.select %0, %arg0, %arg1 => %arg0
2519
2520 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
2521 (cmpRhs == trueVal && cmpLhs == falseVal))
2522 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
2523 }
2524 }
2525
2526 // Constant-fold constant operands over non-splat constant condition.
2527 // select %cst_vec, %cst0, %cst1 => %cst2
2528 if (auto cond =
2529 dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
2530 if (auto lhs =
2531 dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
2532 if (auto rhs =
2533 dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
2534 SmallVector<Attribute> results;
2535 results.reserve(static_cast<size_t>(cond.getNumElements()));
2536 auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
2537 cond.value_end<BoolAttr>());
2538 auto lhsVals = llvm::make_range(lhs.value_begin<Attribute>(),
2539 lhs.value_end<Attribute>());
2540 auto rhsVals = llvm::make_range(rhs.value_begin<Attribute>(),
2541 rhs.value_end<Attribute>());
2542
2543 for (auto [condVal, lhsVal, rhsVal] :
2544 llvm::zip_equal(condVals, lhsVals, rhsVals))
2545 results.push_back(condVal.getValue() ? lhsVal : rhsVal);
2546
2547 return DenseElementsAttr::get(lhs.getType(), results);
2548 }
2549 }
2550 }
2551
2552 return nullptr;
2553}
2554
2555ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
2556 Type conditionType, resultType;
2557 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
2558 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
2559 parser.parseOptionalAttrDict(result.attributes) ||
2560 parser.parseColonType(resultType))
2561 return failure();
2562
2563 // Check for the explicit condition type if this is a masked tensor or vector.
2564 if (succeeded(parser.parseOptionalComma())) {
2565 conditionType = resultType;
2566 if (parser.parseType(resultType))
2567 return failure();
2568 } else {
2569 conditionType = parser.getBuilder().getI1Type();
2570 }
2571
2572 result.addTypes(resultType);
2573 return parser.resolveOperands(operands,
2574 {conditionType, resultType, resultType},
2575 parser.getNameLoc(), result.operands);
2576}
2577
2578void arith::SelectOp::print(OpAsmPrinter &p) {
2579 p << " " << getOperands();
2580 p.printOptionalAttrDict((*this)->getAttrs());
2581 p << " : ";
2582 if (ShapedType condType = dyn_cast<ShapedType>(getCondition().getType()))
2583 p << condType << ", ";
2584 p << getType();
2585}
2586
2587LogicalResult arith::SelectOp::verify() {
2588 Type conditionType = getCondition().getType();
2589 if (conditionType.isSignlessInteger(1))
2590 return success();
2591
2592 // If the result type is a vector or tensor, the type can be a mask with the
2593 // same elements.
2594 Type resultType = getType();
2595 if (!llvm::isa<TensorType, VectorType>(resultType))
2596 return emitOpError() << "expected condition to be a signless i1, but got "
2597 << conditionType;
2598 Type shapedConditionType = getI1SameShape(resultType);
2599 if (conditionType != shapedConditionType) {
2600 return emitOpError() << "expected condition type to have the same shape "
2601 "as the result type, expected "
2602 << shapedConditionType << ", but got "
2603 << conditionType;
2604 }
2605 return success();
2606}
2607//===----------------------------------------------------------------------===//
2608// ShLIOp
2609//===----------------------------------------------------------------------===//
2610
2611OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
2612 // shli(x, 0) -> x
2613 if (matchPattern(adaptor.getRhs(), m_Zero()))
2614 return getLhs();
2615 // Don't fold if shifting more or equal than the bit width.
2616 bool bounded = false;
2618 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2619 bounded = b.ult(b.getBitWidth());
2620 return a.shl(b);
2621 });
2622 return bounded ? result : Attribute();
2623}
2624
2625//===----------------------------------------------------------------------===//
2626// ShRUIOp
2627//===----------------------------------------------------------------------===//
2628
2629OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
2630 // shrui(x, 0) -> x
2631 if (matchPattern(adaptor.getRhs(), m_Zero()))
2632 return getLhs();
2633 // Don't fold if shifting more or equal than the bit width.
2634 bool bounded = false;
2636 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2637 bounded = b.ult(b.getBitWidth());
2638 return a.lshr(b);
2639 });
2640 return bounded ? result : Attribute();
2641}
2642
2643//===----------------------------------------------------------------------===//
2644// ShRSIOp
2645//===----------------------------------------------------------------------===//
2646
2647OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
2648 // shrsi(x, 0) -> x
2649 if (matchPattern(adaptor.getRhs(), m_Zero()))
2650 return getLhs();
2651 // Don't fold if shifting more or equal than the bit width.
2652 bool bounded = false;
2654 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
2655 bounded = b.ult(b.getBitWidth());
2656 return a.ashr(b);
2657 });
2658 return bounded ? result : Attribute();
2659}
2660
2661//===----------------------------------------------------------------------===//
2662// Atomic Enum
2663//===----------------------------------------------------------------------===//
2664
2665/// Returns the identity value attribute associated with an AtomicRMWKind op.
2666TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
2667 OpBuilder &builder, Location loc,
2668 bool useOnlyFiniteValue) {
2669 switch (kind) {
2670 case AtomicRMWKind::maximumf: {
2671 const llvm::fltSemantics &semantic =
2672 llvm::cast<FloatType>(resultType).getFloatSemantics();
2673 APFloat identity = useOnlyFiniteValue
2674 ? APFloat::getLargest(semantic, /*Negative=*/true)
2675 : APFloat::getInf(semantic, /*Negative=*/true);
2676 return builder.getFloatAttr(resultType, identity);
2677 }
2678 case AtomicRMWKind::maxnumf: {
2679 const llvm::fltSemantics &semantic =
2680 llvm::cast<FloatType>(resultType).getFloatSemantics();
2681 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
2682 return builder.getFloatAttr(resultType, identity);
2683 }
2684 case AtomicRMWKind::addf:
2685 case AtomicRMWKind::addi:
2686 case AtomicRMWKind::maxu:
2687 case AtomicRMWKind::ori:
2688 case AtomicRMWKind::xori:
2689 return builder.getZeroAttr(resultType);
2690 case AtomicRMWKind::andi:
2691 return builder.getIntegerAttr(
2692 resultType,
2693 APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
2694 case AtomicRMWKind::maxs:
2695 return builder.getIntegerAttr(
2696 resultType, APInt::getSignedMinValue(
2697 llvm::cast<IntegerType>(resultType).getWidth()));
2698 case AtomicRMWKind::minimumf: {
2699 const llvm::fltSemantics &semantic =
2700 llvm::cast<FloatType>(resultType).getFloatSemantics();
2701 APFloat identity = useOnlyFiniteValue
2702 ? APFloat::getLargest(semantic, /*Negative=*/false)
2703 : APFloat::getInf(semantic, /*Negative=*/false);
2704
2705 return builder.getFloatAttr(resultType, identity);
2706 }
2707 case AtomicRMWKind::minnumf: {
2708 const llvm::fltSemantics &semantic =
2709 llvm::cast<FloatType>(resultType).getFloatSemantics();
2710 APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
2711 return builder.getFloatAttr(resultType, identity);
2712 }
2713 case AtomicRMWKind::mins:
2714 return builder.getIntegerAttr(
2715 resultType, APInt::getSignedMaxValue(
2716 llvm::cast<IntegerType>(resultType).getWidth()));
2717 case AtomicRMWKind::minu:
2718 return builder.getIntegerAttr(
2719 resultType,
2720 APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
2721 case AtomicRMWKind::muli:
2722 return builder.getIntegerAttr(resultType, 1);
2723 case AtomicRMWKind::mulf:
2724 return builder.getFloatAttr(resultType, 1);
2725 // TODO: Add remaining reduction operations.
2726 default:
2727 (void)emitOptionalError(loc, "Reduction operation type not supported");
2728 break;
2729 }
2730 return nullptr;
2731}
2732
2733/// Returns the identity numeric value of the given op.
2734std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
2735 std::optional<AtomicRMWKind> maybeKind =
2737 // Floating-point operations.
2738 .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
2739 .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
2740 .Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
2741 .Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
2742 .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
2743 .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
2744 // Integer operations.
2745 .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
2746 .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
2747 .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
2748 .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
2749 .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
2750 .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
2751 .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
2752 .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
2753 .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
2754 .Default(std::nullopt);
2755 if (!maybeKind) {
2756 return std::nullopt;
2757 }
2758
2759 bool useOnlyFiniteValue = false;
2760 auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
2761 if (fmfOpInterface) {
2762 arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
2763 useOnlyFiniteValue =
2764 bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
2765 }
2766
2767 // Builder only used as helper for attribute creation.
2768 OpBuilder b(op->getContext());
2769 Type resultType = op->getResult(0).getType();
2770
2771 return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
2772 useOnlyFiniteValue);
2773}
2774
2775/// Returns the identity value associated with an AtomicRMWKind op.
2776Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2777 OpBuilder &builder, Location loc,
2778 bool useOnlyFiniteValue) {
2779 auto attr =
2780 getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
2781 return arith::ConstantOp::create(builder, loc, attr);
2782}
2783
2784/// Return the value obtained by applying the reduction operation kind
2785/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
2787 Location loc, Value lhs, Value rhs) {
2788 switch (op) {
2789 case AtomicRMWKind::addf:
2790 return arith::AddFOp::create(builder, loc, lhs, rhs);
2791 case AtomicRMWKind::addi:
2792 return arith::AddIOp::create(builder, loc, lhs, rhs);
2793 case AtomicRMWKind::mulf:
2794 return arith::MulFOp::create(builder, loc, lhs, rhs);
2795 case AtomicRMWKind::muli:
2796 return arith::MulIOp::create(builder, loc, lhs, rhs);
2797 case AtomicRMWKind::maximumf:
2798 return arith::MaximumFOp::create(builder, loc, lhs, rhs);
2799 case AtomicRMWKind::minimumf:
2800 return arith::MinimumFOp::create(builder, loc, lhs, rhs);
2801 case AtomicRMWKind::maxnumf:
2802 return arith::MaxNumFOp::create(builder, loc, lhs, rhs);
2803 case AtomicRMWKind::minnumf:
2804 return arith::MinNumFOp::create(builder, loc, lhs, rhs);
2805 case AtomicRMWKind::maxs:
2806 return arith::MaxSIOp::create(builder, loc, lhs, rhs);
2807 case AtomicRMWKind::mins:
2808 return arith::MinSIOp::create(builder, loc, lhs, rhs);
2809 case AtomicRMWKind::maxu:
2810 return arith::MaxUIOp::create(builder, loc, lhs, rhs);
2811 case AtomicRMWKind::minu:
2812 return arith::MinUIOp::create(builder, loc, lhs, rhs);
2813 case AtomicRMWKind::ori:
2814 return arith::OrIOp::create(builder, loc, lhs, rhs);
2815 case AtomicRMWKind::andi:
2816 return arith::AndIOp::create(builder, loc, lhs, rhs);
2817 case AtomicRMWKind::xori:
2818 return arith::XOrIOp::create(builder, loc, lhs, rhs);
2819 // TODO: Add remaining reduction operations.
2820 default:
2821 (void)emitOptionalError(loc, "Reduction operation type not supported");
2822 break;
2823 }
2824 return nullptr;
2825}
2826
2827//===----------------------------------------------------------------------===//
2828// TableGen'd op method definitions
2829//===----------------------------------------------------------------------===//
2830
2831#define GET_OP_CLASSES
2832#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
2833
2834//===----------------------------------------------------------------------===//
2835// TableGen'd enum attribute definitions
2836//===----------------------------------------------------------------------===//
2837
2838#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Speculation::Speculatability getDivUISpeculatability(Value divisor)
Returns whether an unsigned division by divisor is speculatable.
Definition ArithOps.cpp:711
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs)
Validate a cast that changes the width of a type.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:61
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2)
Definition ArithOps.cpp:68
static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode)
Equivalent to convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
Definition ArithOps.cpp:108
static Type getTypeIfLike(Type type)
Get allowed underlying types for vectors and tensors.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)
Returns true if the predicate is true for two equal operands.
static Value foldDivMul(Value lhs, Value rhs, arith::IntegerOverflowFlags ovfFlags)
Fold (a * b) / b -> a
Definition ArithOps.cpp:672
static bool hasSameEncoding(Type typeA, Type typeB)
Return false if both types are ranked tensors with mismatching encodings.
static Type getUnderlyingType(Type type, type_list< ShapedTypes... >, type_list< ElementTypes... >)
Returns a non-null type only if the provided type is one of the allowed types or one of the allowed s...
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow)
Definition ArithOps.cpp:771
static std::optional< int64_t > getIntegerWidth(Type t)
static Speculation::Speculatability getDivSISpeculatability(Value divisor)
Returns whether a signed division by divisor is speculatable.
Definition ArithOps.cpp:753
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:51
static Attribute getBoolAttribute(Type type, bool value)
Definition ArithOps.cpp:149
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs)
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs)
static LogicalResult verifyExtOp(Op op)
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs)
Definition ArithOps.cpp:56
static int64_t getScalarOrElementWidth(Type type)
Definition ArithOps.cpp:129
static FailureOr< APFloat > convertFloatValue(APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode=llvm::RoundingMode::NearestTiesToEven)
Attempts to convert sourceValue to an APFloat value with targetSemantics and roundingMode,...
static Value foldAndIofAndI(arith::AndIOp op)
Fold and(a, and(a, b)) to and(a, b)
Definition ArithOps.cpp:952
static Type getTypeIfLikeOrMemRef(Type type)
Get allowed underlying types for vectors, tensors, and memrefs.
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition ArithOps.cpp:170
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs)
std::tuple< Types... > * type_list
static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref< APInt(const APInt &, const APInt &)> binFn)
Definition ArithOps.cpp:42
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand)
Definition ArithOps.cpp:433
static FailureOr< APInt > getIntOrSplatIntValue(Attribute attr)
Definition ArithOps.cpp:141
static LogicalResult verifyTruncateOp(Op op)
lhs
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
#define mul(a, b)
#define add(a, b)
#define div(a, b)
LogicalResult matchAndRewrite(CmpFOp op, PatternRewriter &rewriter) const override
static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred, bool isUnsigned)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
IntegerType getI1Type()
Definition Builders.cpp:53
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:663
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
This class helps build Operations.
Definition Builders.h:207
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:457
This class represents a single result from folding an operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:64
bool isIndex() const
Definition Types.cpp:54
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns a floating point value.
Definition Arith.h:92
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:330
static bool classof(Operation *op)
Definition ArithOps.cpp:347
static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value)
Build a constant float op that produces a float of the specified type.
Definition ArithOps.cpp:324
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static void build(OpBuilder &builder, OperationState &result, int64_t value)
Build a constant int op that produces an index.
Definition ArithOps.cpp:353
static bool classof(Operation *op)
Definition ArithOps.cpp:374
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Specialization of arith.constant op that returns an integer value.
Definition Arith.h:54
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width)
Build a constant int op that produces an integer of the specified width.
Definition ArithOps.cpp:251
static bool classof(Operation *op)
Definition ArithOps.cpp:318
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs)
Compute lhs pred rhs, where pred is one of the known integer comparison predicates.
TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value attribute associated with an AtomicRMWKind op.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred)
Invert an integer comparison predicate.
Definition ArithOps.cpp:75
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
Definition ArithOps.cpp:380
auto m_Val(Value v)
Definition Matchers.h:539
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::constant_float_predicate_matcher m_NaNFloat()
Matches a constant scalar / vector splat / tensor splat float NaN.
Definition Matchers.h:421
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Attribute constFoldCastOp(ArrayRef< Attribute > operands, Type resType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
detail::constant_float_predicate_matcher m_PosZeroFloat()
Matches a constant scalar / vector splat / tensor splat float positive zero.
Definition Matchers.h:404
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_float_predicate_matcher m_NegInfFloat()
Matches a constant scalar / vector splat / tensor splat float negative infinity.
Definition Matchers.h:435
detail::constant_float_predicate_matcher m_NegZeroFloat()
Matches a constant scalar / vector splat / tensor splat float negative zero.
Definition Matchers.h:409
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_float_predicate_matcher m_PosInfFloat()
Matches a constant scalar / vector splat / tensor splat float positive infinity.
Definition Matchers.h:427
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
detail::constant_float_predicate_matcher m_OneFloat()
Matches a constant scalar / vector splat / tensor splat float ones.
Definition Matchers.h:414
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition Matchers.h:455
LogicalResult matchAndRewrite(arith::SelectOp op, PatternRewriter &rewriter) const override
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.