MLIR  20.0.0git
IndexOps.cpp
Go to the documentation of this file.
1 //===- IndexOps.cpp - Index operation definitions --------------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::index;
22 
23 //===----------------------------------------------------------------------===//
24 // IndexDialect
25 //===----------------------------------------------------------------------===//
26 
// Registers the Index dialect's operations with the dialect. Defining
// GET_OP_LIST right before including the ODS-generated IndexOps.cpp.inc makes
// that include expand to the op class list used as addOperations<...>'s
// template arguments.
27 void IndexDialect::registerOperations() {
28  addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
31  >();
32 }
33 
35  Type type, Location loc) {
36  // Materialize bool constants as `i1`.
37  if (auto boolValue = dyn_cast<BoolAttr>(value)) {
38  if (!type.isSignlessInteger(1))
39  return nullptr;
40  return b.create<BoolConstantOp>(loc, type, boolValue);
41  }
42 
43  // Materialize integer attributes as `index`.
44  if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
45  if (!llvm::isa<IndexType>(indexValue.getType()) ||
46  !llvm::isa<IndexType>(type))
47  return nullptr;
48  assert(indexValue.getValue().getBitWidth() ==
49  IndexType::kInternalStorageBitWidth);
50  return b.create<ConstantOp>(loc, indexValue);
51  }
52 
53  return nullptr;
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Fold Utilities
58 //===----------------------------------------------------------------------===//
59 
60 /// Fold an index operation irrespective of the target bitwidth. The
61 /// operation must satisfy the property:
62 ///
63 /// ```
64 /// trunc(f(a, b)) = f(trunc(a), trunc(b))
65 /// ```
66 ///
67 /// For all values of `a` and `b`. The function accepts a lambda that computes
68 /// the integer result, which in turn must satisfy the above property.
70  ArrayRef<Attribute> operands,
71  function_ref<std::optional<APInt>(const APInt &, const APInt &)>
72  calculate) {
73  assert(operands.size() == 2 && "binary operation expected 2 operands");
74  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
75  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
76  if (!lhs || !rhs)
77  return {};
78 
79  std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
80  if (!result)
81  return {};
82  assert(result->trunc(32) ==
83  calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
84  return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
85 }
86 
87 /// Fold an index operation only if the truncated 64-bit result matches the
88 /// 32-bit result for operations that don't satisfy the above property. These
89 /// are operations where the upper bits of the operands can affect the lower
90 /// bits of the results.
91 ///
92 /// The function accepts a lambda that computes the integer result in both
93 /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
94 /// not folded.
96  ArrayRef<Attribute> operands,
97  function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
98  calculate) {
99  assert(operands.size() == 2 && "binary operation expected 2 operands");
100  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
101  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
102  // Only fold index operands.
103  if (!lhs || !rhs)
104  return {};
105 
106  // Compute the 64-bit result and the 32-bit result.
107  std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
108  if (!result64)
109  return {};
110  std::optional<APInt> result32 =
111  calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
112  if (!result32)
113  return {};
114  // Compare the truncated 64-bit result to the 32-bit result.
115  if (result64->trunc(32) != *result32)
116  return {};
117  // The operation can be folded for these particular operands.
118  return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
119 }
120 
121 /// Helper for associative and commutative binary ops that can be transformed:
122 /// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
123 /// where c1 and c2 are constants. It is expected that `tmp` will be folded.
124 template <typename BinaryOp>
125 LogicalResult
127  PatternRewriter &rewriter) {
128  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
129  return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
130 
131  auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
132  if (!lhsOp)
133  return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");
134 
135  if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
136  return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");
137 
138  Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
139  lhsOp.getRhs());
140  if (c.getDefiningOp<BinaryOp>())
141  return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");
142 
143  rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
144  return success();
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // AddOp
149 //===----------------------------------------------------------------------===//
150 
151 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
152  if (OpFoldResult result = foldBinaryOpUnchecked(
153  adaptor.getOperands(),
154  [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
155  return result;
156 
157  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
158  // Fold `add(x, 0) -> x`.
159  if (rhs.getValue().isZero())
160  return getLhs();
161  }
162 
163  return {};
164 }
165 
166 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
167  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // SubOp
172 //===----------------------------------------------------------------------===//
173 
174 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
175  if (OpFoldResult result = foldBinaryOpUnchecked(
176  adaptor.getOperands(),
177  [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
178  return result;
179 
180  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
181  // Fold `sub(x, 0) -> x`.
182  if (rhs.getValue().isZero())
183  return getLhs();
184  }
185 
186  return {};
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // MulOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
194  if (OpFoldResult result = foldBinaryOpUnchecked(
195  adaptor.getOperands(),
196  [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
197  return result;
198 
199  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
200  // Fold `mul(x, 1) -> x`.
201  if (rhs.getValue().isOne())
202  return getLhs();
203  // Fold `mul(x, 0) -> 0`.
204  if (rhs.getValue().isZero())
205  return rhs;
206  }
207 
208  return {};
209 }
210 
211 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
212  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // DivSOp
217 //===----------------------------------------------------------------------===//
218 
219 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
220  return foldBinaryOpChecked(
221  adaptor.getOperands(),
222  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
223  // Don't fold division by zero.
224  if (rhs.isZero())
225  return std::nullopt;
226  return lhs.sdiv(rhs);
227  });
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // DivUOp
232 //===----------------------------------------------------------------------===//
233 
234 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
235  return foldBinaryOpChecked(
236  adaptor.getOperands(),
237  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
238  // Don't fold division by zero.
239  if (rhs.isZero())
240  return std::nullopt;
241  return lhs.udiv(rhs);
242  });
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // CeilDivSOp
247 //===----------------------------------------------------------------------===//
248 
249 /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
250 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
251 static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
252  // Don't fold division by zero.
253  if (m.isZero())
254  return std::nullopt;
255  // Short-circuit the zero case.
256  if (n.isZero())
257  return n;
258 
259  bool mGtZ = m.sgt(0);
260  if (n.sgt(0) != mGtZ) {
261  // If the operands have different signs, compute the negative result. Signed
262  // division overflow is not possible, since if `m == -1`, `n` can be at most
263  // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
264  return -(-n).sdiv(m);
265  }
266  // Otherwise, compute the positive result. Signed division overflow is not
267  // possible since if `m == -1`, `x` will be `1`.
268  int64_t x = mGtZ ? -1 : 1;
269  return (n + x).sdiv(m) + 1;
270 }
271 
272 OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
273  return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // CeilDivUOp
278 //===----------------------------------------------------------------------===//
279 
280 OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
281  // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
282  return foldBinaryOpChecked(
283  adaptor.getOperands(),
284  [](const APInt &n, const APInt &m) -> std::optional<APInt> {
285  // Don't fold division by zero.
286  if (m.isZero())
287  return std::nullopt;
288  // Short-circuit the zero case.
289  if (n.isZero())
290  return n;
291 
292  return (n - 1).udiv(m) + 1;
293  });
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // FloorDivSOp
298 //===----------------------------------------------------------------------===//
299 
300 /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
301 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
302 static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
303  // Don't fold division by zero.
304  if (m.isZero())
305  return std::nullopt;
306  // Short-circuit the zero case.
307  if (n.isZero())
308  return n;
309 
310  bool mLtZ = m.slt(0);
311  if (n.slt(0) == mLtZ) {
312  // If the operands have the same sign, compute the positive result.
313  return n.sdiv(m);
314  }
315  // If the operands have different signs, compute the negative result. Signed
316  // division overflow is not possible since if `m == -1`, `x` will be 1 and
317  // `n` can be at most `INT_MAX`.
318  int64_t x = mLtZ ? 1 : -1;
319  return -1 - (x - n).sdiv(m);
320 }
321 
322 OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
323  return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // RemSOp
328 //===----------------------------------------------------------------------===//
329 
330 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
331  return foldBinaryOpChecked(
332  adaptor.getOperands(),
333  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
334  // Don't fold division by zero.
335  if (rhs.isZero())
336  return std::nullopt;
337  return lhs.srem(rhs);
338  });
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // RemUOp
343 //===----------------------------------------------------------------------===//
344 
345 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
346  return foldBinaryOpChecked(
347  adaptor.getOperands(),
348  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
349  // Don't fold division by zero.
350  if (rhs.isZero())
351  return std::nullopt;
352  return lhs.urem(rhs);
353  });
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // MaxSOp
358 //===----------------------------------------------------------------------===//
359 
360 OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
361  return foldBinaryOpChecked(adaptor.getOperands(),
362  [](const APInt &lhs, const APInt &rhs) {
363  return lhs.sgt(rhs) ? lhs : rhs;
364  });
365 }
366 
367 LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
368  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // MaxUOp
373 //===----------------------------------------------------------------------===//
374 
375 OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
376  return foldBinaryOpChecked(adaptor.getOperands(),
377  [](const APInt &lhs, const APInt &rhs) {
378  return lhs.ugt(rhs) ? lhs : rhs;
379  });
380 }
381 
382 LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
383  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // MinSOp
388 //===----------------------------------------------------------------------===//
389 
390 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
391  return foldBinaryOpChecked(adaptor.getOperands(),
392  [](const APInt &lhs, const APInt &rhs) {
393  return lhs.slt(rhs) ? lhs : rhs;
394  });
395 }
396 
397 LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
398  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
399 }
400 
401 //===----------------------------------------------------------------------===//
402 // MinUOp
403 //===----------------------------------------------------------------------===//
404 
405 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
406  return foldBinaryOpChecked(adaptor.getOperands(),
407  [](const APInt &lhs, const APInt &rhs) {
408  return lhs.ult(rhs) ? lhs : rhs;
409  });
410 }
411 
412 LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
413  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // ShlOp
418 //===----------------------------------------------------------------------===//
419 
420 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
421  return foldBinaryOpUnchecked(
422  adaptor.getOperands(),
423  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
424  // We cannot fold if the RHS is greater than or equal to 32 because
425  // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
426  // already treated as unsigned.
427  if (rhs.uge(32))
428  return {};
429  return lhs << rhs;
430  });
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // ShrSOp
435 //===----------------------------------------------------------------------===//
436 
437 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
438  return foldBinaryOpChecked(
439  adaptor.getOperands(),
440  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
441  // Don't fold if RHS is greater than or equal to 32.
442  if (rhs.uge(32))
443  return {};
444  return lhs.ashr(rhs);
445  });
446 }
447 
448 //===----------------------------------------------------------------------===//
449 // ShrUOp
450 //===----------------------------------------------------------------------===//
451 
452 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
453  return foldBinaryOpChecked(
454  adaptor.getOperands(),
455  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
456  // Don't fold if RHS is greater than or equal to 32.
457  if (rhs.uge(32))
458  return {};
459  return lhs.lshr(rhs);
460  });
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // AndOp
465 //===----------------------------------------------------------------------===//
466 
467 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
468  return foldBinaryOpUnchecked(
469  adaptor.getOperands(),
470  [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
471 }
472 
473 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
474  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // OrOp
479 //===----------------------------------------------------------------------===//
480 
481 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
482  return foldBinaryOpUnchecked(
483  adaptor.getOperands(),
484  [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
485 }
486 
487 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
488  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // XOrOp
493 //===----------------------------------------------------------------------===//
494 
495 OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
496  return foldBinaryOpUnchecked(
497  adaptor.getOperands(),
498  [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
499 }
500 
501 LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
502  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
503 }
504 
505 //===----------------------------------------------------------------------===//
506 // CastSOp
507 //===----------------------------------------------------------------------===//
508 
509 static OpFoldResult
511  function_ref<APInt(const APInt &, unsigned)> extFn,
512  function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
513  auto attr = dyn_cast_if_present<IntegerAttr>(input);
514  if (!attr)
515  return {};
516  const APInt &value = attr.getValue();
517 
518  if (isa<IndexType>(type)) {
519  // When casting to an index type, perform the cast assuming a 64-bit target.
520  // The result can be truncated to 32 bits as needed and always be correct.
521  // This is because `cast32(cast64(value)) == cast32(value)`.
522  APInt result = extOrTruncFn(value, 64);
523  return IntegerAttr::get(type, result);
524  }
525 
526  // When casting from an index type, we must ensure the results respect
527  // `cast_t(value) == cast_t(trunc32(value))`.
528  auto intType = cast<IntegerType>(type);
529  unsigned width = intType.getWidth();
530 
531  // If the result type is at most 32 bits, then the cast can always be folded
532  // because it is always a truncation.
533  if (width <= 32) {
534  APInt result = value.trunc(width);
535  return IntegerAttr::get(type, result);
536  }
537 
538  // If the result type is at least 64 bits, then the cast is always a
539  // extension. The results will differ if `trunc32(value) != value)`.
540  if (width >= 64) {
541  if (extFn(value.trunc(32), 64) != value)
542  return {};
543  APInt result = extFn(value, width);
544  return IntegerAttr::get(type, result);
545  }
546 
547  // Otherwise, we just have to check the property directly.
548  APInt result = value.trunc(width);
549  if (result != extFn(value.trunc(32), width))
550  return {};
551  return IntegerAttr::get(type, result);
552 }
553 
554 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
555  return llvm::isa<IndexType>(lhsTypes.front()) !=
556  llvm::isa<IndexType>(rhsTypes.front());
557 }
558 
559 OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
560  return foldCastOp(
561  adaptor.getInput(), getType(),
562  [](const APInt &x, unsigned width) { return x.sext(width); },
563  [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // CastUOp
568 //===----------------------------------------------------------------------===//
569 
570 bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
571  return llvm::isa<IndexType>(lhsTypes.front()) !=
572  llvm::isa<IndexType>(rhsTypes.front());
573 }
574 
575 OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
576  return foldCastOp(
577  adaptor.getInput(), getType(),
578  [](const APInt &x, unsigned width) { return x.zext(width); },
579  [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // CmpOp
584 //===----------------------------------------------------------------------===//
585 
586 /// Compare two integers according to the comparison predicate.
587 bool compareIndices(const APInt &lhs, const APInt &rhs,
588  IndexCmpPredicate pred) {
589  switch (pred) {
590  case IndexCmpPredicate::EQ:
591  return lhs.eq(rhs);
592  case IndexCmpPredicate::NE:
593  return lhs.ne(rhs);
594  case IndexCmpPredicate::SGE:
595  return lhs.sge(rhs);
596  case IndexCmpPredicate::SGT:
597  return lhs.sgt(rhs);
598  case IndexCmpPredicate::SLE:
599  return lhs.sle(rhs);
600  case IndexCmpPredicate::SLT:
601  return lhs.slt(rhs);
602  case IndexCmpPredicate::UGE:
603  return lhs.uge(rhs);
604  case IndexCmpPredicate::UGT:
605  return lhs.ugt(rhs);
606  case IndexCmpPredicate::ULE:
607  return lhs.ule(rhs);
608  case IndexCmpPredicate::ULT:
609  return lhs.ult(rhs);
610  }
611  llvm_unreachable("unhandled IndexCmpPredicate predicate");
612 }
613 
614 /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
615 /// values of `cstA` and `cstB`, the max or min operation, and the comparison
616 /// predicate. Check whether the value folds in both 32-bit and 64-bit
617 /// arithmetic and to the same value.
618 static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
619  const APInt &cstA,
620  const APInt &cstB, unsigned width,
621  IndexCmpPredicate pred) {
623  .Case([&](MinSOp op) {
624  return ConstantIntRanges::fromSigned(
625  APInt::getSignedMinValue(width), cstA);
626  })
627  .Case([&](MinUOp op) {
628  return ConstantIntRanges::fromUnsigned(
629  APInt::getMinValue(width), cstA);
630  })
631  .Case([&](MaxSOp op) {
632  return ConstantIntRanges::fromSigned(
633  cstA, APInt::getSignedMaxValue(width));
634  })
635  .Case([&](MaxUOp op) {
636  return ConstantIntRanges::fromUnsigned(
637  cstA, APInt::getMaxValue(width));
638  });
639  return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred),
640  lhsRange, ConstantIntRanges::constant(cstB));
641 }
642 
643 /// Return the result of `cmp(pred, x, x)`
644 static bool compareSameArgs(IndexCmpPredicate pred) {
645  switch (pred) {
646  case IndexCmpPredicate::EQ:
647  case IndexCmpPredicate::SGE:
648  case IndexCmpPredicate::SLE:
649  case IndexCmpPredicate::UGE:
650  case IndexCmpPredicate::ULE:
651  return true;
652  case IndexCmpPredicate::NE:
653  case IndexCmpPredicate::SGT:
654  case IndexCmpPredicate::SLT:
655  case IndexCmpPredicate::UGT:
656  case IndexCmpPredicate::ULT:
657  return false;
658  }
659  llvm_unreachable("unknown predicate in compareSameArgs");
660 }
661 
662 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
663  // Attempt to fold if both inputs are constant.
664  auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
665  auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
666  if (lhs && rhs) {
667  // Perform the comparison in 64-bit and 32-bit.
668  bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
669  bool result32 = compareIndices(lhs.getValue().trunc(32),
670  rhs.getValue().trunc(32), getPred());
671  if (result64 == result32)
672  return BoolAttr::get(getContext(), result64);
673  }
674 
675  // Fold `cmp(max/min(x, cstA), cstB)`.
676  Operation *lhsOp = getLhs().getDefiningOp();
677  IntegerAttr cstA;
678  if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
679  matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
680  std::optional<bool> result64 = foldCmpOfMaxOrMin(
681  lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
682  std::optional<bool> result32 =
683  foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
684  rhs.getValue().trunc(32), 32, getPred());
685  // Fold if the 32-bit and 64-bit results are the same.
686  if (result64 && result32 && *result64 == *result32)
687  return BoolAttr::get(getContext(), *result64);
688  }
689 
690  // Fold `cmp(x, x)`
691  if (getLhs() == getRhs())
692  return BoolAttr::get(getContext(), compareSameArgs(getPred()));
693 
694  return {};
695 }
696 
697 /// Canonicalize
698 /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
699 /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
700 LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
701  IntegerAttr cmpRhs;
702  IntegerAttr cmpLhs;
703 
704  bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
705  cmpRhs.getValue().isZero();
706  bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
707  cmpLhs.getValue().isZero();
708  if (!rhsIsZero && !lhsIsZero)
709  return rewriter.notifyMatchFailure(op.getLoc(),
710  "cmp is not comparing something with 0");
711  SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
712  : op.getRhs().getDefiningOp<index::SubOp>();
713  if (!subOp)
714  return rewriter.notifyMatchFailure(
715  op.getLoc(), "non-zero operand is not a result of subtraction");
716 
717  index::CmpOp newCmp;
718  if (rhsIsZero)
719  newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
720  subOp.getLhs(), subOp.getRhs());
721  else
722  newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
723  subOp.getRhs(), subOp.getLhs());
724  rewriter.replaceOp(op, newCmp);
725  return success();
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // ConstantOp
730 //===----------------------------------------------------------------------===//
731 
732 void ConstantOp::getAsmResultNames(
733  function_ref<void(Value, StringRef)> setNameFn) {
734  SmallString<32> specialNameBuffer;
735  llvm::raw_svector_ostream specialName(specialNameBuffer);
736  specialName << "idx" << getValueAttr().getValue();
737  setNameFn(getResult(), specialName.str());
738 }
739 
740 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
741 
742 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
743  build(b, state, b.getIndexType(), b.getIndexAttr(value));
744 }
745 
746 //===----------------------------------------------------------------------===//
747 // BoolConstantOp
748 //===----------------------------------------------------------------------===//
749 
750 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
751  return getValueAttr();
752 }
753 
754 void BoolConstantOp::getAsmResultNames(
755  function_ref<void(Value, StringRef)> setNameFn) {
756  setNameFn(getResult(), getValue() ? "true" : "false");
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // ODS-Generated Definitions
761 //===----------------------------------------------------------------------===//
762 
763 #define GET_OP_CLASSES
764 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
Definition: IndexOps.cpp:69
LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)
Helper for associative and commutative binary ops that can be transformed: x = op(v,...
Definition: IndexOps.cpp:126
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
Definition: IndexOps.cpp:587
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
Definition: IndexOps.cpp:95
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
Definition: IndexOps.cpp:618
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
Definition: IndexOps.cpp:510
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).
Definition: IndexOps.cpp:251
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
Definition: IndexOps.cpp:644
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.
Definition: IndexOps.cpp:302
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IndexType getIndexType()
Definition: Builders.cpp:95
A set of arbitrary-precision integers representing bounds on a given integer value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:75
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.