17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/TypeSwitch.h"
27 void IndexDialect::registerOperations() {
30 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
37 if (
auto boolValue = dyn_cast<BoolAttr>(value)) {
40 return b.
create<BoolConstantOp>(loc, type, boolValue);
44 if (
auto indexValue = dyn_cast<IntegerAttr>(value)) {
45 if (!llvm::isa<IndexType>(indexValue.getType()) ||
46 !llvm::isa<IndexType>(type))
48 assert(indexValue.getValue().getBitWidth() ==
49 IndexType::kInternalStorageBitWidth);
50 return b.
create<ConstantOp>(loc, indexValue);
71 function_ref<std::optional<APInt>(
const APInt &,
const APInt &)>
73 assert(operands.size() == 2 &&
"binary operation expected 2 operands");
74 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
75 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
79 std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
82 assert(result->trunc(32) ==
83 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
97 function_ref<std::optional<APInt>(
const APInt &,
const APInt &lhs)>
99 assert(operands.size() == 2 &&
"binary operation expected 2 operands");
100 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
101 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
107 std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
110 std::optional<APInt> result32 =
111 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
115 if (result64->trunc(32) != *result32)
127 adaptor.getOperands(),
128 [](
const APInt &lhs,
const APInt &rhs) { return lhs + rhs; }))
131 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
133 if (rhs.getValue().isZero())
146 adaptor.getOperands(),
147 [](
const APInt &lhs,
const APInt &rhs) { return lhs - rhs; }))
150 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
152 if (rhs.getValue().isZero())
165 adaptor.getOperands(),
166 [](
const APInt &lhs,
const APInt &rhs) { return lhs * rhs; }))
169 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
171 if (rhs.getValue().isOne())
174 if (rhs.getValue().isZero())
187 adaptor.getOperands(),
188 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
192 return lhs.sdiv(rhs);
202 adaptor.getOperands(),
203 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
207 return lhs.udiv(rhs);
225 bool mGtZ = m.sgt(0);
226 if (n.sgt(0) != mGtZ) {
230 return -(-n).sdiv(m);
234 int64_t x = mGtZ ? -1 : 1;
235 return (n + x).sdiv(m) + 1;
249 adaptor.getOperands(),
250 [](
const APInt &n,
const APInt &m) -> std::optional<APInt> {
258 return (n - 1).udiv(m) + 1;
276 bool mLtZ = m.slt(0);
277 if (n.slt(0) == mLtZ) {
284 int64_t x = mLtZ ? 1 : -1;
285 return -1 - (x - n).sdiv(m);
298 adaptor.getOperands(),
299 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
303 return lhs.srem(rhs);
313 adaptor.getOperands(),
314 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
318 return lhs.urem(rhs);
328 [](
const APInt &lhs,
const APInt &rhs) {
329 return lhs.sgt(rhs) ? lhs : rhs;
339 [](
const APInt &lhs,
const APInt &rhs) {
340 return lhs.ugt(rhs) ? lhs : rhs;
350 [](
const APInt &lhs,
const APInt &rhs) {
351 return lhs.slt(rhs) ? lhs : rhs;
361 [](
const APInt &lhs,
const APInt &rhs) {
362 return lhs.ult(rhs) ? lhs : rhs;
372 adaptor.getOperands(),
373 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
389 adaptor.getOperands(),
390 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
394 return lhs.ashr(rhs);
404 adaptor.getOperands(),
405 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
409 return lhs.lshr(rhs);
419 adaptor.getOperands(),
420 [](
const APInt &lhs,
const APInt &rhs) { return lhs & rhs; });
429 adaptor.getOperands(),
430 [](
const APInt &lhs,
const APInt &rhs) { return lhs | rhs; });
439 adaptor.getOperands(),
440 [](
const APInt &lhs,
const APInt &rhs) { return lhs ^ rhs; });
450 function_ref<APInt(
const APInt &,
unsigned)> extOrTruncFn) {
451 auto attr = dyn_cast_if_present<IntegerAttr>(input);
454 const APInt &value = attr.getValue();
456 if (isa<IndexType>(type)) {
460 APInt result = extOrTruncFn(value, 64);
466 auto intType = cast<IntegerType>(type);
467 unsigned width = intType.getWidth();
472 APInt result = value.trunc(width);
479 if (extFn(value.trunc(32), 64) != value)
481 APInt result = extFn(value, width);
486 APInt result = value.trunc(width);
487 if (result != extFn(value.trunc(32), width))
493 return llvm::isa<IndexType>(lhsTypes.front()) !=
494 llvm::isa<IndexType>(rhsTypes.front());
499 adaptor.getInput(), getType(),
500 [](
const APInt &x,
unsigned width) {
return x.sext(width); },
501 [](
const APInt &x,
unsigned width) {
return x.sextOrTrunc(width); });
509 return llvm::isa<IndexType>(lhsTypes.front()) !=
510 llvm::isa<IndexType>(rhsTypes.front());
515 adaptor.getInput(), getType(),
516 [](
const APInt &x,
unsigned width) {
return x.zext(width); },
517 [](
const APInt &x,
unsigned width) {
return x.zextOrTrunc(width); });
526 IndexCmpPredicate pred) {
528 case IndexCmpPredicate::EQ:
530 case IndexCmpPredicate::NE:
532 case IndexCmpPredicate::SGE:
534 case IndexCmpPredicate::SGT:
536 case IndexCmpPredicate::SLE:
538 case IndexCmpPredicate::SLT:
540 case IndexCmpPredicate::UGE:
542 case IndexCmpPredicate::UGT:
544 case IndexCmpPredicate::ULE:
546 case IndexCmpPredicate::ULT:
549 llvm_unreachable(
"unhandled IndexCmpPredicate predicate");
558 const APInt &cstB,
unsigned width,
559 IndexCmpPredicate pred) {
561 .Case([&](MinSOp op) {
562 return ConstantIntRanges::fromSigned(
563 APInt::getSignedMinValue(width), cstA);
565 .Case([&](MinUOp op) {
566 return ConstantIntRanges::fromUnsigned(
567 APInt::getMinValue(width), cstA);
569 .Case([&](MaxSOp op) {
570 return ConstantIntRanges::fromSigned(
571 cstA, APInt::getSignedMaxValue(width));
573 .Case([&](MaxUOp op) {
574 return ConstantIntRanges::fromUnsigned(
575 cstA, APInt::getMaxValue(width));
578 lhsRange, ConstantIntRanges::constant(cstB));
584 case IndexCmpPredicate::EQ:
585 case IndexCmpPredicate::SGE:
586 case IndexCmpPredicate::SLE:
587 case IndexCmpPredicate::UGE:
588 case IndexCmpPredicate::ULE:
590 case IndexCmpPredicate::NE:
591 case IndexCmpPredicate::SGT:
592 case IndexCmpPredicate::SLT:
593 case IndexCmpPredicate::UGT:
594 case IndexCmpPredicate::ULT:
601 auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
602 auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
605 bool result64 =
compareIndices(lhs.getValue(), rhs.getValue(), getPred());
607 rhs.getValue().trunc(32), getPred());
608 if (result64 == result32)
613 Operation *lhsOp = getLhs().getDefiningOp();
615 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
618 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
619 std::optional<bool> result32 =
621 rhs.getValue().trunc(32), 32, getPred());
623 if (result64 && result32 && *result64 == *result32)
628 if (getLhs() == getRhs())
642 cmpRhs.getValue().isZero();
644 cmpLhs.getValue().isZero();
645 if (!rhsIsZero && !lhsIsZero)
647 "cmp is not comparing something with 0");
648 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
649 : op.getRhs().getDefiningOp<index::SubOp>();
652 op.
getLoc(),
"non-zero operand is not a result of subtraction");
656 newCmp = rewriter.
create<index::CmpOp>(op.
getLoc(), op.getPred(),
657 subOp.getLhs(), subOp.getRhs());
659 newCmp = rewriter.
create<index::CmpOp>(op.
getLoc(), op.getPred(),
660 subOp.getRhs(), subOp.getLhs());
669 void ConstantOp::getAsmResultNames(
672 llvm::raw_svector_ostream specialName(specialNameBuffer);
673 specialName <<
"idx" << getValueAttr().getValue();
674 setNameFn(getResult(), specialName.str());
// Fold hook for ConstantOp: the fold result is simply the op's own value
// attribute, as returned by getValueAttr(). The `adaptor` argument is not used.
677 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr(); }
687 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
688 return getValueAttr();
691 void BoolConstantOp::getAsmResultNames(
693 setNameFn(getResult(), getValue() ?
"true" :
"false");
700 #define GET_OP_CLASSES
701 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
A set of arbitrary-precision integers representing bounds on a given integer value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.