16 #include "llvm/ADT/SmallString.h" 
   17 #include "llvm/ADT/TypeSwitch.h" 
   26 void IndexDialect::registerOperations() {
 
   29 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" 
   36   if (
auto boolValue = dyn_cast<BoolAttr>(value)) {
 
   39     return BoolConstantOp::create(b, loc, type, boolValue);
 
   43   if (
auto indexValue = dyn_cast<IntegerAttr>(value)) {
 
   44     if (!llvm::isa<IndexType>(indexValue.getType()) ||
 
   45         !llvm::isa<IndexType>(type))
 
   47     assert(indexValue.getValue().getBitWidth() ==
 
   48            IndexType::kInternalStorageBitWidth);
 
   49     return ConstantOp::create(b, loc, indexValue);
 
   70     function_ref<std::optional<APInt>(
const APInt &, 
const APInt &)>
 
   72   assert(operands.size() == 2 && 
"binary operation expected 2 operands");
 
   73   auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
 
   74   auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
 
   78   std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
 
   81   assert(result->trunc(32) ==
 
   82          calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
 
   96     function_ref<std::optional<APInt>(
const APInt &, 
const APInt &lhs)>
 
   98   assert(operands.size() == 2 && 
"binary operation expected 2 operands");
 
   99   auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
 
  100   auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
 
  106   std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
 
  109   std::optional<APInt> result32 =
 
  110       calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
 
  114   if (result64->trunc(32) != *result32)
 
  123 template <
typename BinaryOp>
 
  130   auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
 
  152           adaptor.getOperands(),
 
  153           [](
const APInt &lhs, 
const APInt &rhs) { return lhs + rhs; }))
 
  156   if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
 
  158     if (rhs.getValue().isZero())
 
  165 LogicalResult AddOp::canonicalize(AddOp op, 
PatternRewriter &rewriter) {
 
  175           adaptor.getOperands(),
 
  176           [](
const APInt &lhs, 
const APInt &rhs) { return lhs - rhs; }))
 
  179   if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
 
  181     if (rhs.getValue().isZero())
 
  194           adaptor.getOperands(),
 
  195           [](
const APInt &lhs, 
const APInt &rhs) { return lhs * rhs; }))
 
  198   if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
 
  200     if (rhs.getValue().isOne())
 
  203     if (rhs.getValue().isZero())
 
  210 LogicalResult MulOp::canonicalize(MulOp op, 
PatternRewriter &rewriter) {
 
  220       adaptor.getOperands(),
 
  221       [](
const APInt &lhs, 
const APInt &rhs) -> std::optional<APInt> {
 
  225         return lhs.sdiv(rhs);
 
  235       adaptor.getOperands(),
 
  236       [](
const APInt &lhs, 
const APInt &rhs) -> std::optional<APInt> {
 
  240         return lhs.udiv(rhs);
 
  258   bool mGtZ = m.sgt(0);
 
  259   if (n.sgt(0) != mGtZ) {
 
  263     return -(-n).sdiv(m);
 
  267   int64_t x = mGtZ ? -1 : 1;
 
  268   return (n + x).sdiv(m) + 1;
 
  282       adaptor.getOperands(),
 
  283       [](
const APInt &n, 
const APInt &m) -> std::optional<APInt> {
 
  291         return (n - 1).udiv(m) + 1;
 
  309   bool mLtZ = m.slt(0);
 
  310   if (n.slt(0) == mLtZ) {
 
  317   int64_t x = mLtZ ? 1 : -1;
 
  318   return -1 - (x - n).sdiv(m);
 
  331       adaptor.getOperands(),
 
  332       [](
const APInt &lhs, 
const APInt &rhs) -> std::optional<APInt> {
 
  336         return lhs.srem(rhs);
 
  346       adaptor.getOperands(),
 
  347       [](
const APInt &lhs, 
const APInt &rhs) -> std::optional<APInt> {
 
  351         return lhs.urem(rhs);
 
  361                              [](
const APInt &lhs, 
const APInt &rhs) {
 
  362                                return lhs.sgt(rhs) ? lhs : rhs;
 
  366 LogicalResult MaxSOp::canonicalize(MaxSOp op, 
PatternRewriter &rewriter) {
 
  376                              [](
const APInt &lhs, 
const APInt &rhs) {
 
  377                                return lhs.ugt(rhs) ? lhs : rhs;
 
  381 LogicalResult MaxUOp::canonicalize(MaxUOp op, 
PatternRewriter &rewriter) {
 
  391                              [](
const APInt &lhs, 
const APInt &rhs) {
 
  392                                return lhs.slt(rhs) ? lhs : rhs;
 
  396 LogicalResult MinSOp::canonicalize(MinSOp op, 
PatternRewriter &rewriter) {
 
  406                              [](
const APInt &lhs, 
const APInt &rhs) {
 
  407                                return lhs.ult(rhs) ? lhs : rhs;
 
  411 LogicalResult MinUOp::canonicalize(MinUOp op, 
PatternRewriter &rewriter) {
 
  421       adaptor.getOperands(),
 
  422       [](
const APInt &lhs, 
const APInt &rhs) -> std::optional<APInt> {
 
  438       adaptor.getOperands(),
 
  439       [](
const APInt &lhs, 
const APInt &rhs) -> std::optional<APInt> {
 
  443         return lhs.ashr(rhs);
 
  453       adaptor.getOperands(),
 
  454       [](
const APInt &lhs, 
const APInt &rhs) -> std::optional<APInt> {
 
  458         return lhs.lshr(rhs);
 
  468       adaptor.getOperands(),
 
  469       [](
const APInt &lhs, 
const APInt &rhs) { return lhs & rhs; });
 
  472 LogicalResult AndOp::canonicalize(AndOp op, 
PatternRewriter &rewriter) {
 
  482       adaptor.getOperands(),
 
  483       [](
const APInt &lhs, 
const APInt &rhs) { return lhs | rhs; });
 
  496       adaptor.getOperands(),
 
  497       [](
const APInt &lhs, 
const APInt &rhs) { return lhs ^ rhs; });
 
  500 LogicalResult XOrOp::canonicalize(XOrOp op, 
PatternRewriter &rewriter) {
 
  511            function_ref<APInt(
const APInt &, 
unsigned)> extOrTruncFn) {
 
  512   auto attr = dyn_cast_if_present<IntegerAttr>(input);
 
  515   const APInt &value = attr.getValue();
 
  517   if (isa<IndexType>(type)) {
 
  521     APInt result = extOrTruncFn(value, 64);
 
  527   auto intType = cast<IntegerType>(type);
 
  528   unsigned width = intType.getWidth();
 
  533     APInt result = value.trunc(width);
 
  540     if (extFn(value.trunc(32), 64) != value)
 
  542     APInt result = extFn(value, width);
 
  547   APInt result = value.trunc(width);
 
  548   if (result != extFn(value.trunc(32), width))
 
  554   return llvm::isa<IndexType>(lhsTypes.front()) !=
 
  555          llvm::isa<IndexType>(rhsTypes.front());
 
  561       [](
const APInt &x, 
unsigned width) { 
return x.sext(width); },
 
  562       [](
const APInt &x, 
unsigned width) { 
return x.sextOrTrunc(width); });
 
  570   return llvm::isa<IndexType>(lhsTypes.front()) !=
 
  571          llvm::isa<IndexType>(rhsTypes.front());
 
  577       [](
const APInt &x, 
unsigned width) { 
return x.zext(width); },
 
  578       [](
const APInt &x, 
unsigned width) { 
return x.zextOrTrunc(width); });
 
  587                     IndexCmpPredicate pred) {
 
  589   case IndexCmpPredicate::EQ:
 
  591   case IndexCmpPredicate::NE:
 
  593   case IndexCmpPredicate::SGE:
 
  595   case IndexCmpPredicate::SGT:
 
  597   case IndexCmpPredicate::SLE:
 
  599   case IndexCmpPredicate::SLT:
 
  601   case IndexCmpPredicate::UGE:
 
  603   case IndexCmpPredicate::UGT:
 
  605   case IndexCmpPredicate::ULE:
 
  607   case IndexCmpPredicate::ULT:
 
  610   llvm_unreachable(
"unhandled IndexCmpPredicate predicate");
 
  619                                              const APInt &cstB, 
unsigned width,
 
  620                                              IndexCmpPredicate pred) {
 
  622                                    .Case([&](MinSOp op) {
 
  623                                      return ConstantIntRanges::fromSigned(
 
  624                                          APInt::getSignedMinValue(width), cstA);
 
  626                                    .Case([&](MinUOp op) {
 
  627                                      return ConstantIntRanges::fromUnsigned(
 
  628                                          APInt::getMinValue(width), cstA);
 
  630                                    .Case([&](MaxSOp op) {
 
  631                                      return ConstantIntRanges::fromSigned(
 
  632                                          cstA, APInt::getSignedMaxValue(width));
 
  634                                    .Case([&](MaxUOp op) {
 
  635                                      return ConstantIntRanges::fromUnsigned(
 
  636                                          cstA, APInt::getMaxValue(width));
 
  639                                 lhsRange, ConstantIntRanges::constant(cstB));
 
  645   case IndexCmpPredicate::EQ:
 
  646   case IndexCmpPredicate::SGE:
 
  647   case IndexCmpPredicate::SLE:
 
  648   case IndexCmpPredicate::UGE:
 
  649   case IndexCmpPredicate::ULE:
 
  651   case IndexCmpPredicate::NE:
 
  652   case IndexCmpPredicate::SGT:
 
  653   case IndexCmpPredicate::SLT:
 
  654   case IndexCmpPredicate::UGT:
 
  655   case IndexCmpPredicate::ULT:
 
  658   llvm_unreachable(
"unknown predicate in compareSameArgs");
 
  663   auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
 
  664   auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
 
  667     bool result64 = 
compareIndices(lhs.getValue(), rhs.getValue(), getPred());
 
  669                                    rhs.getValue().trunc(32), getPred());
 
  670     if (result64 == result32)
 
  675   Operation *lhsOp = getLhs().getDefiningOp();
 
  677   if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
 
  680         lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
 
  681     std::optional<bool> result32 =
 
  683                           rhs.getValue().trunc(32), 32, getPred());
 
  685     if (result64 && result32 && *result64 == *result32)
 
  690   if (getLhs() == getRhs())
 
  699 LogicalResult CmpOp::canonicalize(CmpOp op, 
PatternRewriter &rewriter) {
 
  704                    cmpRhs.getValue().isZero();
 
  706                    cmpLhs.getValue().isZero();
 
  707   if (!rhsIsZero && !lhsIsZero)
 
  709                                        "cmp is not comparing something with 0");
 
  710   SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
 
  711                           : op.getRhs().getDefiningOp<index::SubOp>();
 
  714         op.getLoc(), 
"non-zero operand is not a result of subtraction");
 
  718     newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
 
  719                                   subOp.getLhs(), subOp.getRhs());
 
  721     newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
 
  722                                   subOp.getRhs(), subOp.getLhs());
 
  731 void ConstantOp::getAsmResultNames(
 
  734   llvm::raw_svector_ostream specialName(specialNameBuffer);
 
  735   specialName << 
"idx" << getValueAttr().getValue();
 
  736   setNameFn(getResult(), specialName.str());
 
  739 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { 
return getValueAttr(); }
 
  749 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
 
  750   return getValueAttr();
 
  753 void BoolConstantOp::getAsmResultNames(
 
  755   setNameFn(getResult(), getValue() ? 
"true" : 
"false");
 
  762 #define GET_OP_CLASSES 
  763 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc" 
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)
Helper for associative and commutative binary ops that can be transformed: x = op(v,...
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
A set of arbitrary-precision integers representing bounds on a given integer value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for anypossible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.