16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/TypeSwitch.h"
26 void IndexDialect::registerOperations() {
29 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
36 if (
auto boolValue = dyn_cast<BoolAttr>(value)) {
39 return BoolConstantOp::create(b, loc, type, boolValue);
43 if (
auto indexValue = dyn_cast<IntegerAttr>(value)) {
44 if (!llvm::isa<IndexType>(indexValue.getType()) ||
45 !llvm::isa<IndexType>(type))
47 assert(indexValue.getValue().getBitWidth() ==
48 IndexType::kInternalStorageBitWidth);
49 return ConstantOp::create(b, loc, indexValue);
70 function_ref<std::optional<APInt>(
const APInt &,
const APInt &)>
72 assert(operands.size() == 2 &&
"binary operation expected 2 operands");
73 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
74 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
78 std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
81 assert(result->trunc(32) ==
82 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
96 function_ref<std::optional<APInt>(
const APInt &,
const APInt &lhs)>
98 assert(operands.size() == 2 &&
"binary operation expected 2 operands");
99 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
100 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
106 std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
109 std::optional<APInt> result32 =
110 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
114 if (result64->trunc(32) != *result32)
123 template <
typename BinaryOp>
130 auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
152 adaptor.getOperands(),
153 [](
const APInt &lhs,
const APInt &rhs) { return lhs + rhs; }))
156 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
158 if (rhs.getValue().isZero())
165 LogicalResult AddOp::canonicalize(AddOp op,
PatternRewriter &rewriter) {
175 adaptor.getOperands(),
176 [](
const APInt &lhs,
const APInt &rhs) { return lhs - rhs; }))
179 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
181 if (rhs.getValue().isZero())
194 adaptor.getOperands(),
195 [](
const APInt &lhs,
const APInt &rhs) { return lhs * rhs; }))
198 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
200 if (rhs.getValue().isOne())
203 if (rhs.getValue().isZero())
210 LogicalResult MulOp::canonicalize(MulOp op,
PatternRewriter &rewriter) {
220 adaptor.getOperands(),
221 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
225 return lhs.sdiv(rhs);
235 adaptor.getOperands(),
236 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
240 return lhs.udiv(rhs);
258 bool mGtZ = m.sgt(0);
259 if (n.sgt(0) != mGtZ) {
263 return -(-n).sdiv(m);
267 int64_t x = mGtZ ? -1 : 1;
268 return (n + x).sdiv(m) + 1;
282 adaptor.getOperands(),
283 [](
const APInt &n,
const APInt &m) -> std::optional<APInt> {
291 return (n - 1).udiv(m) + 1;
309 bool mLtZ = m.slt(0);
310 if (n.slt(0) == mLtZ) {
317 int64_t x = mLtZ ? 1 : -1;
318 return -1 - (x - n).sdiv(m);
331 adaptor.getOperands(),
332 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
336 return lhs.srem(rhs);
346 adaptor.getOperands(),
347 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
351 return lhs.urem(rhs);
361 [](
const APInt &lhs,
const APInt &rhs) {
362 return lhs.sgt(rhs) ? lhs : rhs;
366 LogicalResult MaxSOp::canonicalize(MaxSOp op,
PatternRewriter &rewriter) {
376 [](
const APInt &lhs,
const APInt &rhs) {
377 return lhs.ugt(rhs) ? lhs : rhs;
381 LogicalResult MaxUOp::canonicalize(MaxUOp op,
PatternRewriter &rewriter) {
391 [](
const APInt &lhs,
const APInt &rhs) {
392 return lhs.slt(rhs) ? lhs : rhs;
396 LogicalResult MinSOp::canonicalize(MinSOp op,
PatternRewriter &rewriter) {
406 [](
const APInt &lhs,
const APInt &rhs) {
407 return lhs.ult(rhs) ? lhs : rhs;
411 LogicalResult MinUOp::canonicalize(MinUOp op,
PatternRewriter &rewriter) {
421 adaptor.getOperands(),
422 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
438 adaptor.getOperands(),
439 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
443 return lhs.ashr(rhs);
453 adaptor.getOperands(),
454 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
458 return lhs.lshr(rhs);
468 adaptor.getOperands(),
469 [](
const APInt &lhs,
const APInt &rhs) { return lhs & rhs; });
472 LogicalResult AndOp::canonicalize(AndOp op,
PatternRewriter &rewriter) {
482 adaptor.getOperands(),
483 [](
const APInt &lhs,
const APInt &rhs) { return lhs | rhs; });
496 adaptor.getOperands(),
497 [](
const APInt &lhs,
const APInt &rhs) { return lhs ^ rhs; });
500 LogicalResult XOrOp::canonicalize(XOrOp op,
PatternRewriter &rewriter) {
511 function_ref<APInt(
const APInt &,
unsigned)> extOrTruncFn) {
512 auto attr = dyn_cast_if_present<IntegerAttr>(input);
515 const APInt &value = attr.getValue();
517 if (isa<IndexType>(type)) {
521 APInt result = extOrTruncFn(value, 64);
527 auto intType = cast<IntegerType>(type);
528 unsigned width = intType.getWidth();
533 APInt result = value.trunc(width);
540 if (extFn(value.trunc(32), 64) != value)
542 APInt result = extFn(value, width);
547 APInt result = value.trunc(width);
548 if (result != extFn(value.trunc(32), width))
554 return llvm::isa<IndexType>(lhsTypes.front()) !=
555 llvm::isa<IndexType>(rhsTypes.front());
561 [](
const APInt &x,
unsigned width) {
return x.sext(width); },
562 [](
const APInt &x,
unsigned width) {
return x.sextOrTrunc(width); });
570 return llvm::isa<IndexType>(lhsTypes.front()) !=
571 llvm::isa<IndexType>(rhsTypes.front());
577 [](
const APInt &x,
unsigned width) {
return x.zext(width); },
578 [](
const APInt &x,
unsigned width) {
return x.zextOrTrunc(width); });
587 IndexCmpPredicate pred) {
589 case IndexCmpPredicate::EQ:
591 case IndexCmpPredicate::NE:
593 case IndexCmpPredicate::SGE:
595 case IndexCmpPredicate::SGT:
597 case IndexCmpPredicate::SLE:
599 case IndexCmpPredicate::SLT:
601 case IndexCmpPredicate::UGE:
603 case IndexCmpPredicate::UGT:
605 case IndexCmpPredicate::ULE:
607 case IndexCmpPredicate::ULT:
610 llvm_unreachable(
"unhandled IndexCmpPredicate predicate");
619 const APInt &cstB,
unsigned width,
620 IndexCmpPredicate pred) {
622 .Case([&](MinSOp op) {
623 return ConstantIntRanges::fromSigned(
624 APInt::getSignedMinValue(width), cstA);
626 .Case([&](MinUOp op) {
627 return ConstantIntRanges::fromUnsigned(
628 APInt::getMinValue(width), cstA);
630 .Case([&](MaxSOp op) {
631 return ConstantIntRanges::fromSigned(
632 cstA, APInt::getSignedMaxValue(width));
634 .Case([&](MaxUOp op) {
635 return ConstantIntRanges::fromUnsigned(
636 cstA, APInt::getMaxValue(width));
639 lhsRange, ConstantIntRanges::constant(cstB));
645 case IndexCmpPredicate::EQ:
646 case IndexCmpPredicate::SGE:
647 case IndexCmpPredicate::SLE:
648 case IndexCmpPredicate::UGE:
649 case IndexCmpPredicate::ULE:
651 case IndexCmpPredicate::NE:
652 case IndexCmpPredicate::SGT:
653 case IndexCmpPredicate::SLT:
654 case IndexCmpPredicate::UGT:
655 case IndexCmpPredicate::ULT:
658 llvm_unreachable(
"unknown predicate in compareSameArgs");
663 auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
664 auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
667 bool result64 =
compareIndices(lhs.getValue(), rhs.getValue(), getPred());
669 rhs.getValue().trunc(32), getPred());
670 if (result64 == result32)
675 Operation *lhsOp = getLhs().getDefiningOp();
677 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
680 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
681 std::optional<bool> result32 =
683 rhs.getValue().trunc(32), 32, getPred());
685 if (result64 && result32 && *result64 == *result32)
690 if (getLhs() == getRhs())
699 LogicalResult CmpOp::canonicalize(CmpOp op,
PatternRewriter &rewriter) {
704 cmpRhs.getValue().isZero();
706 cmpLhs.getValue().isZero();
707 if (!rhsIsZero && !lhsIsZero)
709 "cmp is not comparing something with 0");
710 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
711 : op.getRhs().getDefiningOp<index::SubOp>();
714 op.getLoc(),
"non-zero operand is not a result of subtraction");
718 newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
719 subOp.getLhs(), subOp.getRhs());
721 newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
722 subOp.getRhs(), subOp.getLhs());
731 void ConstantOp::getAsmResultNames(
734 llvm::raw_svector_ostream specialName(specialNameBuffer);
735 specialName <<
"idx" << getValueAttr().getValue();
736 setNameFn(getResult(), specialName.str());
739 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr(); }
749 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
750 return getValueAttr();
753 void BoolConstantOp::getAsmResultNames(
755 setNameFn(getResult(), getValue() ?
"true" :
"false");
762 #define GET_OP_CLASSES
763 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)
Helper for associative and commutative binary ops that can be transformed: x = op(v,...
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
A set of arbitrary-precision integers representing bounds on a given integer value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for anypossible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.