17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/TypeSwitch.h"
27 void IndexDialect::registerOperations() {
30 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
37 if (
auto boolValue = dyn_cast<BoolAttr>(value)) {
40 return b.
create<BoolConstantOp>(loc, type, boolValue);
44 if (
auto indexValue = dyn_cast<IntegerAttr>(value)) {
45 if (!llvm::isa<IndexType>(indexValue.getType()) ||
46 !llvm::isa<IndexType>(type))
48 assert(indexValue.getValue().getBitWidth() ==
49 IndexType::kInternalStorageBitWidth);
50 return b.
create<ConstantOp>(loc, indexValue);
71 function_ref<std::optional<APInt>(
const APInt &,
const APInt &)>
73 assert(operands.size() == 2 &&
"binary operation expected 2 operands");
74 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
75 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
79 std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
82 assert(result->trunc(32) ==
83 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
97 function_ref<std::optional<APInt>(
const APInt &,
const APInt &lhs)>
99 assert(operands.size() == 2 &&
"binary operation expected 2 operands");
100 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
101 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
107 std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
110 std::optional<APInt> result32 =
111 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
115 if (result64->trunc(32) != *result32)
124 template <
typename BinaryOp>
131 auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
153 adaptor.getOperands(),
154 [](
const APInt &lhs,
const APInt &rhs) { return lhs + rhs; }))
157 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
159 if (rhs.getValue().isZero())
166 LogicalResult AddOp::canonicalize(AddOp op,
PatternRewriter &rewriter) {
176 adaptor.getOperands(),
177 [](
const APInt &lhs,
const APInt &rhs) { return lhs - rhs; }))
180 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
182 if (rhs.getValue().isZero())
195 adaptor.getOperands(),
196 [](
const APInt &lhs,
const APInt &rhs) { return lhs * rhs; }))
199 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
201 if (rhs.getValue().isOne())
204 if (rhs.getValue().isZero())
211 LogicalResult MulOp::canonicalize(MulOp op,
PatternRewriter &rewriter) {
221 adaptor.getOperands(),
222 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
226 return lhs.sdiv(rhs);
236 adaptor.getOperands(),
237 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
241 return lhs.udiv(rhs);
259 bool mGtZ = m.sgt(0);
260 if (n.sgt(0) != mGtZ) {
264 return -(-n).sdiv(m);
268 int64_t x = mGtZ ? -1 : 1;
269 return (n + x).sdiv(m) + 1;
283 adaptor.getOperands(),
284 [](
const APInt &n,
const APInt &m) -> std::optional<APInt> {
292 return (n - 1).udiv(m) + 1;
310 bool mLtZ = m.slt(0);
311 if (n.slt(0) == mLtZ) {
318 int64_t x = mLtZ ? 1 : -1;
319 return -1 - (x - n).sdiv(m);
332 adaptor.getOperands(),
333 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
337 return lhs.srem(rhs);
347 adaptor.getOperands(),
348 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
352 return lhs.urem(rhs);
362 [](
const APInt &lhs,
const APInt &rhs) {
363 return lhs.sgt(rhs) ? lhs : rhs;
367 LogicalResult MaxSOp::canonicalize(MaxSOp op,
PatternRewriter &rewriter) {
377 [](
const APInt &lhs,
const APInt &rhs) {
378 return lhs.ugt(rhs) ? lhs : rhs;
382 LogicalResult MaxUOp::canonicalize(MaxUOp op,
PatternRewriter &rewriter) {
392 [](
const APInt &lhs,
const APInt &rhs) {
393 return lhs.slt(rhs) ? lhs : rhs;
397 LogicalResult MinSOp::canonicalize(MinSOp op,
PatternRewriter &rewriter) {
407 [](
const APInt &lhs,
const APInt &rhs) {
408 return lhs.ult(rhs) ? lhs : rhs;
412 LogicalResult MinUOp::canonicalize(MinUOp op,
PatternRewriter &rewriter) {
422 adaptor.getOperands(),
423 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
439 adaptor.getOperands(),
440 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
444 return lhs.ashr(rhs);
454 adaptor.getOperands(),
455 [](
const APInt &lhs,
const APInt &rhs) -> std::optional<APInt> {
459 return lhs.lshr(rhs);
469 adaptor.getOperands(),
470 [](
const APInt &lhs,
const APInt &rhs) { return lhs & rhs; });
473 LogicalResult AndOp::canonicalize(AndOp op,
PatternRewriter &rewriter) {
483 adaptor.getOperands(),
484 [](
const APInt &lhs,
const APInt &rhs) { return lhs | rhs; });
497 adaptor.getOperands(),
498 [](
const APInt &lhs,
const APInt &rhs) { return lhs ^ rhs; });
501 LogicalResult XOrOp::canonicalize(XOrOp op,
PatternRewriter &rewriter) {
512 function_ref<APInt(
const APInt &,
unsigned)> extOrTruncFn) {
513 auto attr = dyn_cast_if_present<IntegerAttr>(input);
516 const APInt &value = attr.getValue();
518 if (isa<IndexType>(type)) {
522 APInt result = extOrTruncFn(value, 64);
528 auto intType = cast<IntegerType>(type);
529 unsigned width = intType.getWidth();
534 APInt result = value.trunc(width);
541 if (extFn(value.trunc(32), 64) != value)
543 APInt result = extFn(value, width);
548 APInt result = value.trunc(width);
549 if (result != extFn(value.trunc(32), width))
555 return llvm::isa<IndexType>(lhsTypes.front()) !=
556 llvm::isa<IndexType>(rhsTypes.front());
562 [](
const APInt &x,
unsigned width) {
return x.sext(width); },
563 [](
const APInt &x,
unsigned width) {
return x.sextOrTrunc(width); });
571 return llvm::isa<IndexType>(lhsTypes.front()) !=
572 llvm::isa<IndexType>(rhsTypes.front());
578 [](
const APInt &x,
unsigned width) {
return x.zext(width); },
579 [](
const APInt &x,
unsigned width) {
return x.zextOrTrunc(width); });
588 IndexCmpPredicate pred) {
590 case IndexCmpPredicate::EQ:
592 case IndexCmpPredicate::NE:
594 case IndexCmpPredicate::SGE:
596 case IndexCmpPredicate::SGT:
598 case IndexCmpPredicate::SLE:
600 case IndexCmpPredicate::SLT:
602 case IndexCmpPredicate::UGE:
604 case IndexCmpPredicate::UGT:
606 case IndexCmpPredicate::ULE:
608 case IndexCmpPredicate::ULT:
611 llvm_unreachable(
"unhandled IndexCmpPredicate predicate");
620 const APInt &cstB,
unsigned width,
621 IndexCmpPredicate pred) {
623 .Case([&](MinSOp op) {
624 return ConstantIntRanges::fromSigned(
625 APInt::getSignedMinValue(width), cstA);
627 .Case([&](MinUOp op) {
628 return ConstantIntRanges::fromUnsigned(
629 APInt::getMinValue(width), cstA);
631 .Case([&](MaxSOp op) {
632 return ConstantIntRanges::fromSigned(
633 cstA, APInt::getSignedMaxValue(width));
635 .Case([&](MaxUOp op) {
636 return ConstantIntRanges::fromUnsigned(
637 cstA, APInt::getMaxValue(width));
640 lhsRange, ConstantIntRanges::constant(cstB));
646 case IndexCmpPredicate::EQ:
647 case IndexCmpPredicate::SGE:
648 case IndexCmpPredicate::SLE:
649 case IndexCmpPredicate::UGE:
650 case IndexCmpPredicate::ULE:
652 case IndexCmpPredicate::NE:
653 case IndexCmpPredicate::SGT:
654 case IndexCmpPredicate::SLT:
655 case IndexCmpPredicate::UGT:
656 case IndexCmpPredicate::ULT:
659 llvm_unreachable(
"unknown predicate in compareSameArgs");
664 auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
665 auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
668 bool result64 =
compareIndices(lhs.getValue(), rhs.getValue(), getPred());
670 rhs.getValue().trunc(32), getPred());
671 if (result64 == result32)
676 Operation *lhsOp = getLhs().getDefiningOp();
678 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
681 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
682 std::optional<bool> result32 =
684 rhs.getValue().trunc(32), 32, getPred());
686 if (result64 && result32 && *result64 == *result32)
691 if (getLhs() == getRhs())
700 LogicalResult CmpOp::canonicalize(CmpOp op,
PatternRewriter &rewriter) {
705 cmpRhs.getValue().isZero();
707 cmpLhs.getValue().isZero();
708 if (!rhsIsZero && !lhsIsZero)
710 "cmp is not comparing something with 0");
711 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
712 : op.getRhs().getDefiningOp<index::SubOp>();
715 op.getLoc(),
"non-zero operand is not a result of subtraction");
719 newCmp = rewriter.
create<index::CmpOp>(op.getLoc(), op.getPred(),
720 subOp.getLhs(), subOp.getRhs());
722 newCmp = rewriter.
create<index::CmpOp>(op.getLoc(), op.getPred(),
723 subOp.getRhs(), subOp.getLhs());
732 void ConstantOp::getAsmResultNames(
735 llvm::raw_svector_ostream specialName(specialNameBuffer);
736 specialName <<
"idx" << getValueAttr().getValue();
737 setNameFn(getResult(), specialName.str());
// ConstantOp always folds to its own value attribute. `adaptor` is unused
// because the body reads no operands; the result is just the stored value.
740 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr(); }
750 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
751 return getValueAttr();
754 void BoolConstantOp::getAsmResultNames(
756 setNameFn(getResult(), getValue() ?
"true" :
"false");
763 #define GET_OP_CLASSES
764 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)
Helper for associative and commutative binary ops that can be transformed: x = op(v,...
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
A set of arbitrary-precision integers representing bounds on a given integer value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.