MLIR  22.0.0git
IndexOps.cpp
Go to the documentation of this file.
1 //===- IndexOps.cpp - Index operation definitions --------------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 
19 using namespace mlir;
20 using namespace mlir::index;
21 
22 //===----------------------------------------------------------------------===//
23 // IndexDialect
24 //===----------------------------------------------------------------------===//
25 
// Registers the dialect's operations. The template argument list handed to
// `addOperations` is not written out here: it is whatever the generated
// `IndexOps.cpp.inc` expands to while `GET_OP_LIST` is defined.
26 void IndexDialect::registerOperations() {
27  addOperations<
28 #define GET_OP_LIST
29 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
30  >();
31 }
32 
// NOTE(review): the line that opens this definition (return type, function
// name and leading parameters) is absent from this listing. The body below
// reads `b` and `value`, which are presumably declared on that missing line --
// confirm against the original file.
//
// Builds a constant operation for `value` with result type `type` at `loc`:
//   * a `BoolAttr` becomes a `BoolConstantOp`, but only for signless `i1`;
//   * an `IntegerAttr` becomes a `ConstantOp`, but only when both the
//     attribute's own type and the requested type are `IndexType`;
//   * every other combination yields `nullptr`.
34  Type type, Location loc) {
35  // Materialize bool constants as `i1`.
36  if (auto boolValue = dyn_cast<BoolAttr>(value)) {
37  if (!type.isSignlessInteger(1))
38  return nullptr;
39  return BoolConstantOp::create(b, loc, type, boolValue);
40  }
41 
42  // Materialize integer attributes as `index`.
43  if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
44  if (!llvm::isa<IndexType>(indexValue.getType()) ||
45  !llvm::isa<IndexType>(type))
46  return nullptr;
// Index attributes are expected to carry the internal storage width.
47  assert(indexValue.getValue().getBitWidth() ==
48  IndexType::kInternalStorageBitWidth);
49  return ConstantOp::create(b, loc, indexValue);
50  }
51 
// Neither a bool nor an integer attribute: nothing to materialize.
52  return nullptr;
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // Fold Utilities
57 //===----------------------------------------------------------------------===//
58 
59 /// Fold an index operation irrespective of the target bitwidth. The
60 /// operation must satisfy the property:
61 ///
62 /// ```
63 /// trunc(f(a, b)) = f(trunc(a), trunc(b))
64 /// ```
65 ///
66 /// For all values of `a` and `b`. The function accepts a lambda that computes
67 /// the integer result, which in turn must satisfy the above property.
69  ArrayRef<Attribute> operands,
70  function_ref<std::optional<APInt>(const APInt &, const APInt &)>
71  calculate) {
72  assert(operands.size() == 2 && "binary operation expected 2 operands");
73  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
74  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
75  if (!lhs || !rhs)
76  return {};
77 
78  std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
79  if (!result)
80  return {};
81  assert(result->trunc(32) ==
82  calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
83  return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
84 }
85 
86 /// Fold an index operation only if the truncated 64-bit result matches the
87 /// 32-bit result for operations that don't satisfy the above property. These
88 /// are operations where the upper bits of the operands can affect the lower
89 /// bits of the results.
90 ///
91 /// The function accepts a lambda that computes the integer result in both
92 /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
93 /// not folded.
95  ArrayRef<Attribute> operands,
96  function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
97  calculate) {
98  assert(operands.size() == 2 && "binary operation expected 2 operands");
99  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
100  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
101  // Only fold index operands.
102  if (!lhs || !rhs)
103  return {};
104 
105  // Compute the 64-bit result and the 32-bit result.
106  std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
107  if (!result64)
108  return {};
109  std::optional<APInt> result32 =
110  calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
111  if (!result32)
112  return {};
113  // Compare the truncated 64-bit result to the 32-bit result.
114  if (result64->trunc(32) != *result32)
115  return {};
116  // The operation can be folded for these particular operands.
117  return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
118 }
119 
120 /// Helper for associative and commutative binary ops that can be transformed:
121 /// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
122 /// where c1 and c2 are constants. It is expected that `tmp` will be folded.
123 template <typename BinaryOp>
124 LogicalResult
126  PatternRewriter &rewriter) {
127  if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
128  return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
129 
130  auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
131  if (!lhsOp)
132  return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");
133 
134  if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
135  return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");
136 
137  Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
138  lhsOp.getRhs());
139  if (c.getDefiningOp<BinaryOp>())
140  return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");
141 
142  rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
143  return success();
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // AddOp
148 //===----------------------------------------------------------------------===//
149 
150 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
151  if (OpFoldResult result = foldBinaryOpUnchecked(
152  adaptor.getOperands(),
153  [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
154  return result;
155 
156  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
157  // Fold `add(x, 0) -> x`.
158  if (rhs.getValue().isZero())
159  return getLhs();
160  }
161 
162  return {};
163 }
164 
165 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
166  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // SubOp
171 //===----------------------------------------------------------------------===//
172 
173 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
174  if (OpFoldResult result = foldBinaryOpUnchecked(
175  adaptor.getOperands(),
176  [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
177  return result;
178 
179  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
180  // Fold `sub(x, 0) -> x`.
181  if (rhs.getValue().isZero())
182  return getLhs();
183  }
184 
185  return {};
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // MulOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
193  if (OpFoldResult result = foldBinaryOpUnchecked(
194  adaptor.getOperands(),
195  [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
196  return result;
197 
198  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
199  // Fold `mul(x, 1) -> x`.
200  if (rhs.getValue().isOne())
201  return getLhs();
202  // Fold `mul(x, 0) -> 0`.
203  if (rhs.getValue().isZero())
204  return rhs;
205  }
206 
207  return {};
208 }
209 
210 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
211  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // DivSOp
216 //===----------------------------------------------------------------------===//
217 
218 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
219  return foldBinaryOpChecked(
220  adaptor.getOperands(),
221  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
222  // Don't fold division by zero.
223  if (rhs.isZero())
224  return std::nullopt;
225  return lhs.sdiv(rhs);
226  });
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // DivUOp
231 //===----------------------------------------------------------------------===//
232 
233 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
234  return foldBinaryOpChecked(
235  adaptor.getOperands(),
236  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
237  // Don't fold division by zero.
238  if (rhs.isZero())
239  return std::nullopt;
240  return lhs.udiv(rhs);
241  });
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // CeilDivSOp
246 //===----------------------------------------------------------------------===//
247 
248 /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
249 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
250 static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
251  // Don't fold division by zero.
252  if (m.isZero())
253  return std::nullopt;
254  // Short-circuit the zero case.
255  if (n.isZero())
256  return n;
257 
258  bool mGtZ = m.sgt(0);
259  if (n.sgt(0) != mGtZ) {
260  // If the operands have different signs, compute the negative result. Signed
261  // division overflow is not possible, since if `m == -1`, `n` can be at most
262  // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
263  return -(-n).sdiv(m);
264  }
265  // Otherwise, compute the positive result. Signed division overflow is not
266  // possible since if `m == -1`, `x` will be `1`.
267  int64_t x = mGtZ ? -1 : 1;
268  return (n + x).sdiv(m) + 1;
269 }
270 
271 OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
272  return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // CeilDivUOp
277 //===----------------------------------------------------------------------===//
278 
279 OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
280  // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
281  return foldBinaryOpChecked(
282  adaptor.getOperands(),
283  [](const APInt &n, const APInt &m) -> std::optional<APInt> {
284  // Don't fold division by zero.
285  if (m.isZero())
286  return std::nullopt;
287  // Short-circuit the zero case.
288  if (n.isZero())
289  return n;
290 
291  return (n - 1).udiv(m) + 1;
292  });
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // FloorDivSOp
297 //===----------------------------------------------------------------------===//
298 
299 /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
300 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
301 static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
302  // Don't fold division by zero.
303  if (m.isZero())
304  return std::nullopt;
305  // Short-circuit the zero case.
306  if (n.isZero())
307  return n;
308 
309  bool mLtZ = m.slt(0);
310  if (n.slt(0) == mLtZ) {
311  // If the operands have the same sign, compute the positive result.
312  return n.sdiv(m);
313  }
314  // If the operands have different signs, compute the negative result. Signed
315  // division overflow is not possible since if `m == -1`, `x` will be 1 and
316  // `n` can be at most `INT_MAX`.
317  int64_t x = mLtZ ? 1 : -1;
318  return -1 - (x - n).sdiv(m);
319 }
320 
321 OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
322  return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // RemSOp
327 //===----------------------------------------------------------------------===//
328 
329 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
330  return foldBinaryOpChecked(
331  adaptor.getOperands(),
332  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
333  // Don't fold division by zero.
334  if (rhs.isZero())
335  return std::nullopt;
336  return lhs.srem(rhs);
337  });
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // RemUOp
342 //===----------------------------------------------------------------------===//
343 
344 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
345  return foldBinaryOpChecked(
346  adaptor.getOperands(),
347  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
348  // Don't fold division by zero.
349  if (rhs.isZero())
350  return std::nullopt;
351  return lhs.urem(rhs);
352  });
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // MaxSOp
357 //===----------------------------------------------------------------------===//
358 
359 OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
360  return foldBinaryOpChecked(adaptor.getOperands(),
361  [](const APInt &lhs, const APInt &rhs) {
362  return lhs.sgt(rhs) ? lhs : rhs;
363  });
364 }
365 
366 LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
367  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // MaxUOp
372 //===----------------------------------------------------------------------===//
373 
374 OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
375  return foldBinaryOpChecked(adaptor.getOperands(),
376  [](const APInt &lhs, const APInt &rhs) {
377  return lhs.ugt(rhs) ? lhs : rhs;
378  });
379 }
380 
381 LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
382  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // MinSOp
387 //===----------------------------------------------------------------------===//
388 
389 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
390  return foldBinaryOpChecked(adaptor.getOperands(),
391  [](const APInt &lhs, const APInt &rhs) {
392  return lhs.slt(rhs) ? lhs : rhs;
393  });
394 }
395 
396 LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
397  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // MinUOp
402 //===----------------------------------------------------------------------===//
403 
404 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
405  return foldBinaryOpChecked(adaptor.getOperands(),
406  [](const APInt &lhs, const APInt &rhs) {
407  return lhs.ult(rhs) ? lhs : rhs;
408  });
409 }
410 
411 LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
412  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // ShlOp
417 //===----------------------------------------------------------------------===//
418 
419 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
420  return foldBinaryOpUnchecked(
421  adaptor.getOperands(),
422  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
423  // We cannot fold if the RHS is greater than or equal to 32 because
424  // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
425  // already treated as unsigned.
426  if (rhs.uge(32))
427  return {};
428  return lhs << rhs;
429  });
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // ShrSOp
434 //===----------------------------------------------------------------------===//
435 
436 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
437  return foldBinaryOpChecked(
438  adaptor.getOperands(),
439  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
440  // Don't fold if RHS is greater than or equal to 32.
441  if (rhs.uge(32))
442  return {};
443  return lhs.ashr(rhs);
444  });
445 }
446 
447 //===----------------------------------------------------------------------===//
448 // ShrUOp
449 //===----------------------------------------------------------------------===//
450 
451 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
452  return foldBinaryOpChecked(
453  adaptor.getOperands(),
454  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
455  // Don't fold if RHS is greater than or equal to 32.
456  if (rhs.uge(32))
457  return {};
458  return lhs.lshr(rhs);
459  });
460 }
461 
462 //===----------------------------------------------------------------------===//
463 // AndOp
464 //===----------------------------------------------------------------------===//
465 
466 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
467  return foldBinaryOpUnchecked(
468  adaptor.getOperands(),
469  [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
470 }
471 
472 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
473  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // OrOp
478 //===----------------------------------------------------------------------===//
479 
480 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
481  return foldBinaryOpUnchecked(
482  adaptor.getOperands(),
483  [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
484 }
485 
486 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
487  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
488 }
489 
490 //===----------------------------------------------------------------------===//
491 // XOrOp
492 //===----------------------------------------------------------------------===//
493 
494 OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
495  return foldBinaryOpUnchecked(
496  adaptor.getOperands(),
497  [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
498 }
499 
500 LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
501  return canonicalizeAssociativeCommutativeBinaryOp(op, rewriter);
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // CastSOp
506 //===----------------------------------------------------------------------===//
507 
508 static OpFoldResult
510  function_ref<APInt(const APInt &, unsigned)> extFn,
511  function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
512  auto attr = dyn_cast_if_present<IntegerAttr>(input);
513  if (!attr)
514  return {};
515  const APInt &value = attr.getValue();
516 
517  if (isa<IndexType>(type)) {
518  // When casting to an index type, perform the cast assuming a 64-bit target.
519  // The result can be truncated to 32 bits as needed and always be correct.
520  // This is because `cast32(cast64(value)) == cast32(value)`.
521  APInt result = extOrTruncFn(value, 64);
522  return IntegerAttr::get(type, result);
523  }
524 
525  // When casting from an index type, we must ensure the results respect
526  // `cast_t(value) == cast_t(trunc32(value))`.
527  auto intType = cast<IntegerType>(type);
528  unsigned width = intType.getWidth();
529 
530  // If the result type is at most 32 bits, then the cast can always be folded
531  // because it is always a truncation.
532  if (width <= 32) {
533  APInt result = value.trunc(width);
534  return IntegerAttr::get(type, result);
535  }
536 
537  // If the result type is at least 64 bits, then the cast is always a
538  // extension. The results will differ if `trunc32(value) != value)`.
539  if (width >= 64) {
540  if (extFn(value.trunc(32), 64) != value)
541  return {};
542  APInt result = extFn(value, width);
543  return IntegerAttr::get(type, result);
544  }
545 
546  // Otherwise, we just have to check the property directly.
547  APInt result = value.trunc(width);
548  if (result != extFn(value.trunc(32), width))
549  return {};
550  return IntegerAttr::get(type, result);
551 }
552 
553 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
554  return llvm::isa<IndexType>(lhsTypes.front()) !=
555  llvm::isa<IndexType>(rhsTypes.front());
556 }
557 
558 OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
559  return foldCastOp(
560  adaptor.getInput(), getType(),
561  [](const APInt &x, unsigned width) { return x.sext(width); },
562  [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
563 }
564 
565 //===----------------------------------------------------------------------===//
566 // CastUOp
567 //===----------------------------------------------------------------------===//
568 
569 bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
570  return llvm::isa<IndexType>(lhsTypes.front()) !=
571  llvm::isa<IndexType>(rhsTypes.front());
572 }
573 
574 OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
575  return foldCastOp(
576  adaptor.getInput(), getType(),
577  [](const APInt &x, unsigned width) { return x.zext(width); },
578  [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // CmpOp
583 //===----------------------------------------------------------------------===//
584 
585 /// Compare two integers according to the comparison predicate.
586 bool compareIndices(const APInt &lhs, const APInt &rhs,
587  IndexCmpPredicate pred) {
588  switch (pred) {
589  case IndexCmpPredicate::EQ:
590  return lhs.eq(rhs);
591  case IndexCmpPredicate::NE:
592  return lhs.ne(rhs);
593  case IndexCmpPredicate::SGE:
594  return lhs.sge(rhs);
595  case IndexCmpPredicate::SGT:
596  return lhs.sgt(rhs);
597  case IndexCmpPredicate::SLE:
598  return lhs.sle(rhs);
599  case IndexCmpPredicate::SLT:
600  return lhs.slt(rhs);
601  case IndexCmpPredicate::UGE:
602  return lhs.uge(rhs);
603  case IndexCmpPredicate::UGT:
604  return lhs.ugt(rhs);
605  case IndexCmpPredicate::ULE:
606  return lhs.ule(rhs);
607  case IndexCmpPredicate::ULT:
608  return lhs.ult(rhs);
609  }
610  llvm_unreachable("unhandled IndexCmpPredicate predicate");
611 }
612 
613 /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
614 /// values of `cstA` and `cstB`, the max or min operation, and the comparison
615 /// predicate. Check whether the value folds in both 32-bit and 64-bit
616 /// arithmetic and to the same value.
617 static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
618  const APInt &cstA,
619  const APInt &cstB, unsigned width,
620  IndexCmpPredicate pred) {
622  .Case([&](MinSOp op) {
623  return ConstantIntRanges::fromSigned(
624  APInt::getSignedMinValue(width), cstA);
625  })
626  .Case([&](MinUOp op) {
627  return ConstantIntRanges::fromUnsigned(
628  APInt::getMinValue(width), cstA);
629  })
630  .Case([&](MaxSOp op) {
631  return ConstantIntRanges::fromSigned(
632  cstA, APInt::getSignedMaxValue(width));
633  })
634  .Case([&](MaxUOp op) {
635  return ConstantIntRanges::fromUnsigned(
636  cstA, APInt::getMaxValue(width));
637  });
638  return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred),
639  lhsRange, ConstantIntRanges::constant(cstB));
640 }
641 
642 /// Return the result of `cmp(pred, x, x)`
643 static bool compareSameArgs(IndexCmpPredicate pred) {
644  switch (pred) {
645  case IndexCmpPredicate::EQ:
646  case IndexCmpPredicate::SGE:
647  case IndexCmpPredicate::SLE:
648  case IndexCmpPredicate::UGE:
649  case IndexCmpPredicate::ULE:
650  return true;
651  case IndexCmpPredicate::NE:
652  case IndexCmpPredicate::SGT:
653  case IndexCmpPredicate::SLT:
654  case IndexCmpPredicate::UGT:
655  case IndexCmpPredicate::ULT:
656  return false;
657  }
658  llvm_unreachable("unknown predicate in compareSameArgs");
659 }
660 
661 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
662  // Attempt to fold if both inputs are constant.
663  auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
664  auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
665  if (lhs && rhs) {
666  // Perform the comparison in 64-bit and 32-bit.
667  bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
668  bool result32 = compareIndices(lhs.getValue().trunc(32),
669  rhs.getValue().trunc(32), getPred());
670  if (result64 == result32)
671  return BoolAttr::get(getContext(), result64);
672  }
673 
674  // Fold `cmp(max/min(x, cstA), cstB)`.
675  Operation *lhsOp = getLhs().getDefiningOp();
676  IntegerAttr cstA;
677  if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
678  matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
679  std::optional<bool> result64 = foldCmpOfMaxOrMin(
680  lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
681  std::optional<bool> result32 =
682  foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
683  rhs.getValue().trunc(32), 32, getPred());
684  // Fold if the 32-bit and 64-bit results are the same.
685  if (result64 && result32 && *result64 == *result32)
686  return BoolAttr::get(getContext(), *result64);
687  }
688 
689  // Fold `cmp(x, x)`
690  if (getLhs() == getRhs())
691  return BoolAttr::get(getContext(), compareSameArgs(getPred()));
692 
693  return {};
694 }
695 
696 /// Canonicalize
697 /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
698 /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
699 LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
700  IntegerAttr cmpRhs;
701  IntegerAttr cmpLhs;
702 
703  bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
704  cmpRhs.getValue().isZero();
705  bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
706  cmpLhs.getValue().isZero();
707  if (!rhsIsZero && !lhsIsZero)
708  return rewriter.notifyMatchFailure(op.getLoc(),
709  "cmp is not comparing something with 0");
710  SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
711  : op.getRhs().getDefiningOp<index::SubOp>();
712  if (!subOp)
713  return rewriter.notifyMatchFailure(
714  op.getLoc(), "non-zero operand is not a result of subtraction");
715 
716  index::CmpOp newCmp;
717  if (rhsIsZero)
718  newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
719  subOp.getLhs(), subOp.getRhs());
720  else
721  newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
722  subOp.getRhs(), subOp.getLhs());
723  rewriter.replaceOp(op, newCmp);
724  return success();
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // ConstantOp
729 //===----------------------------------------------------------------------===//
730 
731 void ConstantOp::getAsmResultNames(
732  function_ref<void(Value, StringRef)> setNameFn) {
733  SmallString<32> specialNameBuffer;
734  llvm::raw_svector_ostream specialName(specialNameBuffer);
735  specialName << "idx" << getValueAttr().getValue();
736  setNameFn(getResult(), specialName.str());
737 }
738 
739 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
740 
741 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
742  build(b, state, b.getIndexType(), b.getIndexAttr(value));
743 }
744 
745 //===----------------------------------------------------------------------===//
746 // BoolConstantOp
747 //===----------------------------------------------------------------------===//
748 
749 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
750  return getValueAttr();
751 }
752 
753 void BoolConstantOp::getAsmResultNames(
754  function_ref<void(Value, StringRef)> setNameFn) {
755  setNameFn(getResult(), getValue() ? "true" : "false");
756 }
757 
758 //===----------------------------------------------------------------------===//
759 // ODS-Generated Definitions
760 //===----------------------------------------------------------------------===//
761 
762 #define GET_OP_CLASSES
763 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
Definition: IndexOps.cpp:68
LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)
Helper for associative and commutative binary ops that can be transformed: x = op(v,...
Definition: IndexOps.cpp:125
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
Definition: IndexOps.cpp:586
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
Definition: IndexOps.cpp:94
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
Definition: IndexOps.cpp:617
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
Definition: IndexOps.cpp:509
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).
Definition: IndexOps.cpp:250
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
Definition: IndexOps.cpp:643
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.
Definition: IndexOps.cpp:301
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
IndexType getIndexType()
Definition: Builders.cpp:50
A set of arbitrary-precision integers representing bounds on a given integer value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
This represents an operation in an abstracted form, suitable for use with the builder APIs.