16#include "llvm/ADT/SmallString.h"
17#include "llvm/ADT/TypeSwitch.h"
26void IndexDialect::registerOperations() {
29#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
36 if (
auto boolValue = dyn_cast<BoolAttr>(value)) {
39 return BoolConstantOp::create(
b, loc, type, boolValue);
43 if (
auto indexValue = dyn_cast<IntegerAttr>(value)) {
44 if (!llvm::isa<IndexType>(indexValue.getType()) ||
45 !llvm::isa<IndexType>(type))
47 assert(indexValue.getValue().getBitWidth() ==
48 IndexType::kInternalStorageBitWidth);
49 return ConstantOp::create(
b, loc, indexValue);
70 function_ref<std::optional<APInt>(
const APInt &,
const APInt &)>
72 assert(operands.size() == 2 &&
"binary operation expected 2 operands");
73 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
74 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
78 std::optional<APInt>
result = calculate(
lhs.getValue(),
rhs.getValue());
81 assert(
result->trunc(32) ==
82 calculate(
lhs.getValue().trunc(32),
rhs.getValue().trunc(32)));
83 return IntegerAttr::get(IndexType::get(
lhs.getContext()), *
result);
98 assert(operands.size() == 2 &&
"binary operation expected 2 operands");
99 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
100 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
106 std::optional<APInt> result64 = calculate(
lhs.getValue(),
rhs.getValue());
109 std::optional<APInt> result32 =
110 calculate(
lhs.getValue().trunc(32),
rhs.getValue().trunc(32));
114 if (result64->trunc(32) != *result32)
117 return IntegerAttr::get(IndexType::get(
lhs.getContext()), *result64);
123template <
typename BinaryOp>
130 auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
152 adaptor.getOperands(),
153 [](
const APInt &
lhs,
const APInt &
rhs) { return lhs + rhs; }))
156 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
158 if (
rhs.getValue().isZero())
165LogicalResult AddOp::canonicalize(AddOp op,
PatternRewriter &rewriter) {
175 adaptor.getOperands(),
176 [](
const APInt &
lhs,
const APInt &
rhs) { return lhs - rhs; }))
179 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
181 if (
rhs.getValue().isZero())
194 adaptor.getOperands(),
195 [](
const APInt &
lhs,
const APInt &
rhs) { return lhs * rhs; }))
198 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
200 if (
rhs.getValue().isOne())
203 if (
rhs.getValue().isZero())
210LogicalResult MulOp::canonicalize(MulOp op,
PatternRewriter &rewriter) {
220 adaptor.getOperands(),
221 [](
const APInt &
lhs,
const APInt &
rhs) -> std::optional<APInt> {
225 return lhs.sdiv(rhs);
235 adaptor.getOperands(),
236 [](
const APInt &
lhs,
const APInt &
rhs) -> std::optional<APInt> {
240 return lhs.udiv(rhs);
258 bool mGtZ = m.sgt(0);
259 if (n.sgt(0) != mGtZ) {
263 return -(-n).sdiv(m);
268 return (n + x).sdiv(m) + 1;
282 adaptor.getOperands(),
283 [](
const APInt &n,
const APInt &m) -> std::optional<APInt> {
291 return (n - 1).udiv(m) + 1;
309 bool mLtZ = m.slt(0);
310 if (n.slt(0) == mLtZ) {
318 return -1 - (x - n).sdiv(m);
331 adaptor.getOperands(),
332 [](
const APInt &
lhs,
const APInt &
rhs) -> std::optional<APInt> {
336 return lhs.srem(rhs);
346 adaptor.getOperands(),
347 [](
const APInt &
lhs,
const APInt &
rhs) -> std::optional<APInt> {
351 return lhs.urem(rhs);
361 [](
const APInt &
lhs,
const APInt &
rhs) {
362 return lhs.sgt(rhs) ? lhs : rhs;
366LogicalResult MaxSOp::canonicalize(MaxSOp op,
PatternRewriter &rewriter) {
376 [](
const APInt &
lhs,
const APInt &
rhs) {
377 return lhs.ugt(rhs) ? lhs : rhs;
381LogicalResult MaxUOp::canonicalize(MaxUOp op,
PatternRewriter &rewriter) {
391 [](
const APInt &
lhs,
const APInt &
rhs) {
392 return lhs.slt(rhs) ? lhs : rhs;
396LogicalResult MinSOp::canonicalize(MinSOp op,
PatternRewriter &rewriter) {
406 [](
const APInt &
lhs,
const APInt &
rhs) {
407 return lhs.ult(rhs) ? lhs : rhs;
411LogicalResult MinUOp::canonicalize(MinUOp op,
PatternRewriter &rewriter) {
421 adaptor.getOperands(),
422 [](
const APInt &
lhs,
const APInt &
rhs) -> std::optional<APInt> {
438 adaptor.getOperands(),
439 [](
const APInt &
lhs,
const APInt &
rhs) -> std::optional<APInt> {
453 adaptor.getOperands(),
454 [](
const APInt &
lhs,
const APInt &
rhs) -> std::optional<APInt> {
468 adaptor.getOperands(),
469 [](
const APInt &
lhs,
const APInt &
rhs) { return lhs & rhs; });
472LogicalResult AndOp::canonicalize(AndOp op,
PatternRewriter &rewriter) {
482 adaptor.getOperands(),
483 [](
const APInt &
lhs,
const APInt &
rhs) { return lhs | rhs; });
496 adaptor.getOperands(),
497 [](
const APInt &
lhs,
const APInt &
rhs) { return lhs ^ rhs; });
500LogicalResult XOrOp::canonicalize(XOrOp op,
PatternRewriter &rewriter) {
511 function_ref<APInt(
const APInt &,
unsigned)> extOrTruncFn) {
512 auto attr = dyn_cast_if_present<IntegerAttr>(input);
515 const APInt &value = attr.getValue();
517 if (isa<IndexType>(type)) {
521 APInt
result = extOrTruncFn(value, 64);
522 return IntegerAttr::get(type,
result);
527 auto intType = cast<IntegerType>(type);
528 unsigned width = intType.getWidth();
533 APInt
result = value.trunc(width);
534 return IntegerAttr::get(type,
result);
540 if (extFn(value.trunc(32), 64) != value)
542 APInt
result = extFn(value, width);
543 return IntegerAttr::get(type,
result);
547 APInt
result = value.trunc(width);
548 if (
result != extFn(value.trunc(32), width))
550 return IntegerAttr::get(type,
result);
554 return llvm::isa<IndexType>(lhsTypes.front()) !=
555 llvm::isa<IndexType>(rhsTypes.front());
561 [](
const APInt &x,
unsigned width) {
return x.sext(width); },
562 [](
const APInt &x,
unsigned width) {
return x.sextOrTrunc(width); });
570 return llvm::isa<IndexType>(lhsTypes.front()) !=
571 llvm::isa<IndexType>(rhsTypes.front());
577 [](
const APInt &x,
unsigned width) {
return x.zext(width); },
578 [](
const APInt &x,
unsigned width) {
return x.zextOrTrunc(width); });
587 IndexCmpPredicate pred) {
589 case IndexCmpPredicate::EQ:
591 case IndexCmpPredicate::NE:
593 case IndexCmpPredicate::SGE:
595 case IndexCmpPredicate::SGT:
597 case IndexCmpPredicate::SLE:
599 case IndexCmpPredicate::SLT:
601 case IndexCmpPredicate::UGE:
603 case IndexCmpPredicate::UGT:
605 case IndexCmpPredicate::ULE:
607 case IndexCmpPredicate::ULT:
610 llvm_unreachable(
"unhandled IndexCmpPredicate predicate");
619 const APInt &cstB,
unsigned width,
620 IndexCmpPredicate pred) {
622 .Case([&](MinSOp op) {
624 APInt::getSignedMinValue(width), cstA);
626 .Case([&](MinUOp op) {
628 APInt::getMinValue(width), cstA);
630 .Case([&](MaxSOp op) {
632 cstA, APInt::getSignedMaxValue(width));
634 .Case([&](MaxUOp op) {
636 cstA, APInt::getMaxValue(width));
645 case IndexCmpPredicate::EQ:
646 case IndexCmpPredicate::SGE:
647 case IndexCmpPredicate::SLE:
648 case IndexCmpPredicate::UGE:
649 case IndexCmpPredicate::ULE:
651 case IndexCmpPredicate::NE:
652 case IndexCmpPredicate::SGT:
653 case IndexCmpPredicate::SLT:
654 case IndexCmpPredicate::UGT:
655 case IndexCmpPredicate::ULT:
658 llvm_unreachable(
"unknown predicate in compareSameArgs");
663 auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
664 auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
669 rhs.getValue().trunc(32), getPred());
670 if (result64 == result32)
675 Operation *lhsOp = getLhs().getDefiningOp();
677 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
680 lhsOp, cstA.getValue(),
rhs.getValue(), 64, getPred());
681 std::optional<bool> result32 =
683 rhs.getValue().trunc(32), 32, getPred());
685 if (result64 && result32 && *result64 == *result32)
690 if (getLhs() == getRhs())
699LogicalResult CmpOp::canonicalize(CmpOp op,
PatternRewriter &rewriter) {
704 cmpRhs.getValue().isZero();
706 cmpLhs.getValue().isZero();
707 if (!rhsIsZero && !lhsIsZero)
709 "cmp is not comparing something with 0");
710 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
711 : op.getRhs().getDefiningOp<
index::SubOp>();
714 op.getLoc(),
"non-zero operand is not a result of subtraction");
718 newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
719 subOp.getLhs(), subOp.getRhs());
721 newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
722 subOp.getRhs(), subOp.getLhs());
731void ConstantOp::getAsmResultNames(
734 llvm::raw_svector_ostream specialName(specialNameBuffer);
735 specialName <<
"idx" << getValueAttr().getValue();
736 setNameFn(getResult(), specialName.str());
739OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr(); }
742 build(
b, state,
b.getIndexType(),
b.getIndexAttr(value));
750 return getValueAttr();
753void BoolConstantOp::getAsmResultNames(
755 setNameFn(getResult(), getValue() ?
"true" :
"false");
762#define GET_OP_CLASSES
763#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ?
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ?
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)
Helper for associative and commutative binary ops that can be transformed: x = op(v,...
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create a ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create a ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::TypeSwitch< T, ResultT > TypeSwitch
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
This represents an operation in an abstracted form, suitable for use with the builder APIs.