MLIR  19.0.0git
IndexOps.cpp
Go to the documentation of this file.
1 //===- IndexOps.cpp - Index operation definitions --------------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/SmallString.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::index;
22 
23 //===----------------------------------------------------------------------===//
24 // IndexDialect
25 //===----------------------------------------------------------------------===//
26 
27 void IndexDialect::registerOperations() {
28  addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
31  >();
32 }
33 
35  Type type, Location loc) {
36  // Materialize bool constants as `i1`.
37  if (auto boolValue = dyn_cast<BoolAttr>(value)) {
38  if (!type.isSignlessInteger(1))
39  return nullptr;
40  return b.create<BoolConstantOp>(loc, type, boolValue);
41  }
42 
43  // Materialize integer attributes as `index`.
44  if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
45  if (!llvm::isa<IndexType>(indexValue.getType()) ||
46  !llvm::isa<IndexType>(type))
47  return nullptr;
48  assert(indexValue.getValue().getBitWidth() ==
49  IndexType::kInternalStorageBitWidth);
50  return b.create<ConstantOp>(loc, indexValue);
51  }
52 
53  return nullptr;
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Fold Utilities
58 //===----------------------------------------------------------------------===//
59 
60 /// Fold an index operation irrespective of the target bitwidth. The
61 /// operation must satisfy the property:
62 ///
63 /// ```
64 /// trunc(f(a, b)) = f(trunc(a), trunc(b))
65 /// ```
66 ///
67 /// For all values of `a` and `b`. The function accepts a lambda that computes
68 /// the integer result, which in turn must satisfy the above property.
70  ArrayRef<Attribute> operands,
71  function_ref<std::optional<APInt>(const APInt &, const APInt &)>
72  calculate) {
73  assert(operands.size() == 2 && "binary operation expected 2 operands");
74  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
75  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
76  if (!lhs || !rhs)
77  return {};
78 
79  std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
80  if (!result)
81  return {};
82  assert(result->trunc(32) ==
83  calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
84  return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
85 }
86 
87 /// Fold an index operation only if the truncated 64-bit result matches the
88 /// 32-bit result for operations that don't satisfy the above property. These
89 /// are operations where the upper bits of the operands can affect the lower
90 /// bits of the results.
91 ///
92 /// The function accepts a lambda that computes the integer result in both
93 /// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
94 /// not folded.
96  ArrayRef<Attribute> operands,
97  function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
98  calculate) {
99  assert(operands.size() == 2 && "binary operation expected 2 operands");
100  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
101  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
102  // Only fold index operands.
103  if (!lhs || !rhs)
104  return {};
105 
106  // Compute the 64-bit result and the 32-bit result.
107  std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
108  if (!result64)
109  return {};
110  std::optional<APInt> result32 =
111  calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
112  if (!result32)
113  return {};
114  // Compare the truncated 64-bit result to the 32-bit result.
115  if (result64->trunc(32) != *result32)
116  return {};
117  // The operation can be folded for these particular operands.
118  return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // AddOp
123 //===----------------------------------------------------------------------===//
124 
125 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
126  if (OpFoldResult result = foldBinaryOpUnchecked(
127  adaptor.getOperands(),
128  [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
129  return result;
130 
131  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
132  // Fold `add(x, 0) -> x`.
133  if (rhs.getValue().isZero())
134  return getLhs();
135  }
136 
137  return {};
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // SubOp
142 //===----------------------------------------------------------------------===//
143 
144 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
145  if (OpFoldResult result = foldBinaryOpUnchecked(
146  adaptor.getOperands(),
147  [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
148  return result;
149 
150  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
151  // Fold `sub(x, 0) -> x`.
152  if (rhs.getValue().isZero())
153  return getLhs();
154  }
155 
156  return {};
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // MulOp
161 //===----------------------------------------------------------------------===//
162 
163 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
164  if (OpFoldResult result = foldBinaryOpUnchecked(
165  adaptor.getOperands(),
166  [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
167  return result;
168 
169  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
170  // Fold `mul(x, 1) -> x`.
171  if (rhs.getValue().isOne())
172  return getLhs();
173  // Fold `mul(x, 0) -> 0`.
174  if (rhs.getValue().isZero())
175  return rhs;
176  }
177 
178  return {};
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // DivSOp
183 //===----------------------------------------------------------------------===//
184 
185 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
186  return foldBinaryOpChecked(
187  adaptor.getOperands(),
188  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
189  // Don't fold division by zero.
190  if (rhs.isZero())
191  return std::nullopt;
192  return lhs.sdiv(rhs);
193  });
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // DivUOp
198 //===----------------------------------------------------------------------===//
199 
200 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
201  return foldBinaryOpChecked(
202  adaptor.getOperands(),
203  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
204  // Don't fold division by zero.
205  if (rhs.isZero())
206  return std::nullopt;
207  return lhs.udiv(rhs);
208  });
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // CeilDivSOp
213 //===----------------------------------------------------------------------===//
214 
215 /// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
216 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
217 static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
218  // Don't fold division by zero.
219  if (m.isZero())
220  return std::nullopt;
221  // Short-circuit the zero case.
222  if (n.isZero())
223  return n;
224 
225  bool mGtZ = m.sgt(0);
226  if (n.sgt(0) != mGtZ) {
227  // If the operands have different signs, compute the negative result. Signed
228  // division overflow is not possible, since if `m == -1`, `n` can be at most
229  // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
230  return -(-n).sdiv(m);
231  }
232  // Otherwise, compute the positive result. Signed division overflow is not
233  // possible since if `m == -1`, `x` will be `1`.
234  int64_t x = mGtZ ? -1 : 1;
235  return (n + x).sdiv(m) + 1;
236 }
237 
238 OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
239  return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // CeilDivUOp
244 //===----------------------------------------------------------------------===//
245 
246 OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
247  // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
248  return foldBinaryOpChecked(
249  adaptor.getOperands(),
250  [](const APInt &n, const APInt &m) -> std::optional<APInt> {
251  // Don't fold division by zero.
252  if (m.isZero())
253  return std::nullopt;
254  // Short-circuit the zero case.
255  if (n.isZero())
256  return n;
257 
258  return (n - 1).udiv(m) + 1;
259  });
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // FloorDivSOp
264 //===----------------------------------------------------------------------===//
265 
266 /// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
267 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
268 static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
269  // Don't fold division by zero.
270  if (m.isZero())
271  return std::nullopt;
272  // Short-circuit the zero case.
273  if (n.isZero())
274  return n;
275 
276  bool mLtZ = m.slt(0);
277  if (n.slt(0) == mLtZ) {
278  // If the operands have the same sign, compute the positive result.
279  return n.sdiv(m);
280  }
281  // If the operands have different signs, compute the negative result. Signed
282  // division overflow is not possible since if `m == -1`, `x` will be 1 and
283  // `n` can be at most `INT_MAX`.
284  int64_t x = mLtZ ? 1 : -1;
285  return -1 - (x - n).sdiv(m);
286 }
287 
288 OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
289  return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // RemSOp
294 //===----------------------------------------------------------------------===//
295 
296 OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
297  return foldBinaryOpChecked(
298  adaptor.getOperands(),
299  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
300  // Don't fold division by zero.
301  if (rhs.isZero())
302  return std::nullopt;
303  return lhs.srem(rhs);
304  });
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // RemUOp
309 //===----------------------------------------------------------------------===//
310 
311 OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
312  return foldBinaryOpChecked(
313  adaptor.getOperands(),
314  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
315  // Don't fold division by zero.
316  if (rhs.isZero())
317  return std::nullopt;
318  return lhs.urem(rhs);
319  });
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // MaxSOp
324 //===----------------------------------------------------------------------===//
325 
326 OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
327  return foldBinaryOpChecked(adaptor.getOperands(),
328  [](const APInt &lhs, const APInt &rhs) {
329  return lhs.sgt(rhs) ? lhs : rhs;
330  });
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // MaxUOp
335 //===----------------------------------------------------------------------===//
336 
337 OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
338  return foldBinaryOpChecked(adaptor.getOperands(),
339  [](const APInt &lhs, const APInt &rhs) {
340  return lhs.ugt(rhs) ? lhs : rhs;
341  });
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // MinSOp
346 //===----------------------------------------------------------------------===//
347 
348 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
349  return foldBinaryOpChecked(adaptor.getOperands(),
350  [](const APInt &lhs, const APInt &rhs) {
351  return lhs.slt(rhs) ? lhs : rhs;
352  });
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // MinUOp
357 //===----------------------------------------------------------------------===//
358 
359 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
360  return foldBinaryOpChecked(adaptor.getOperands(),
361  [](const APInt &lhs, const APInt &rhs) {
362  return lhs.ult(rhs) ? lhs : rhs;
363  });
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // ShlOp
368 //===----------------------------------------------------------------------===//
369 
370 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
371  return foldBinaryOpUnchecked(
372  adaptor.getOperands(),
373  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
374  // We cannot fold if the RHS is greater than or equal to 32 because
375  // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
376  // already treated as unsigned.
377  if (rhs.uge(32))
378  return {};
379  return lhs << rhs;
380  });
381 }
382 
383 //===----------------------------------------------------------------------===//
384 // ShrSOp
385 //===----------------------------------------------------------------------===//
386 
387 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
388  return foldBinaryOpChecked(
389  adaptor.getOperands(),
390  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
391  // Don't fold if RHS is greater than or equal to 32.
392  if (rhs.uge(32))
393  return {};
394  return lhs.ashr(rhs);
395  });
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // ShrUOp
400 //===----------------------------------------------------------------------===//
401 
402 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
403  return foldBinaryOpChecked(
404  adaptor.getOperands(),
405  [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
406  // Don't fold if RHS is greater than or equal to 32.
407  if (rhs.uge(32))
408  return {};
409  return lhs.lshr(rhs);
410  });
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // AndOp
415 //===----------------------------------------------------------------------===//
416 
417 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
418  return foldBinaryOpUnchecked(
419  adaptor.getOperands(),
420  [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // OrOp
425 //===----------------------------------------------------------------------===//
426 
427 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
428  return foldBinaryOpUnchecked(
429  adaptor.getOperands(),
430  [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
431 }
432 
433 //===----------------------------------------------------------------------===//
434 // XOrOp
435 //===----------------------------------------------------------------------===//
436 
437 OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
438  return foldBinaryOpUnchecked(
439  adaptor.getOperands(),
440  [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
441 }
442 
443 //===----------------------------------------------------------------------===//
444 // CastSOp
445 //===----------------------------------------------------------------------===//
446 
447 static OpFoldResult
449  function_ref<APInt(const APInt &, unsigned)> extFn,
450  function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
451  auto attr = dyn_cast_if_present<IntegerAttr>(input);
452  if (!attr)
453  return {};
454  const APInt &value = attr.getValue();
455 
456  if (isa<IndexType>(type)) {
457  // When casting to an index type, perform the cast assuming a 64-bit target.
458  // The result can be truncated to 32 bits as needed and always be correct.
459  // This is because `cast32(cast64(value)) == cast32(value)`.
460  APInt result = extOrTruncFn(value, 64);
461  return IntegerAttr::get(type, result);
462  }
463 
464  // When casting from an index type, we must ensure the results respect
465  // `cast_t(value) == cast_t(trunc32(value))`.
466  auto intType = cast<IntegerType>(type);
467  unsigned width = intType.getWidth();
468 
469  // If the result type is at most 32 bits, then the cast can always be folded
470  // because it is always a truncation.
471  if (width <= 32) {
472  APInt result = value.trunc(width);
473  return IntegerAttr::get(type, result);
474  }
475 
476  // If the result type is at least 64 bits, then the cast is always a
477  // extension. The results will differ if `trunc32(value) != value)`.
478  if (width >= 64) {
479  if (extFn(value.trunc(32), 64) != value)
480  return {};
481  APInt result = extFn(value, width);
482  return IntegerAttr::get(type, result);
483  }
484 
485  // Otherwise, we just have to check the property directly.
486  APInt result = value.trunc(width);
487  if (result != extFn(value.trunc(32), width))
488  return {};
489  return IntegerAttr::get(type, result);
490 }
491 
492 bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
493  return llvm::isa<IndexType>(lhsTypes.front()) !=
494  llvm::isa<IndexType>(rhsTypes.front());
495 }
496 
497 OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
498  return foldCastOp(
499  adaptor.getInput(), getType(),
500  [](const APInt &x, unsigned width) { return x.sext(width); },
501  [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // CastUOp
506 //===----------------------------------------------------------------------===//
507 
508 bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
509  return llvm::isa<IndexType>(lhsTypes.front()) !=
510  llvm::isa<IndexType>(rhsTypes.front());
511 }
512 
513 OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
514  return foldCastOp(
515  adaptor.getInput(), getType(),
516  [](const APInt &x, unsigned width) { return x.zext(width); },
517  [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // CmpOp
522 //===----------------------------------------------------------------------===//
523 
524 /// Compare two integers according to the comparison predicate.
525 bool compareIndices(const APInt &lhs, const APInt &rhs,
526  IndexCmpPredicate pred) {
527  switch (pred) {
528  case IndexCmpPredicate::EQ:
529  return lhs.eq(rhs);
530  case IndexCmpPredicate::NE:
531  return lhs.ne(rhs);
532  case IndexCmpPredicate::SGE:
533  return lhs.sge(rhs);
534  case IndexCmpPredicate::SGT:
535  return lhs.sgt(rhs);
536  case IndexCmpPredicate::SLE:
537  return lhs.sle(rhs);
538  case IndexCmpPredicate::SLT:
539  return lhs.slt(rhs);
540  case IndexCmpPredicate::UGE:
541  return lhs.uge(rhs);
542  case IndexCmpPredicate::UGT:
543  return lhs.ugt(rhs);
544  case IndexCmpPredicate::ULE:
545  return lhs.ule(rhs);
546  case IndexCmpPredicate::ULT:
547  return lhs.ult(rhs);
548  }
549  llvm_unreachable("unhandled IndexCmpPredicate predicate");
550 }
551 
552 /// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
553 /// values of `cstA` and `cstB`, the max or min operation, and the comparison
554 /// predicate. Check whether the value folds in both 32-bit and 64-bit
555 /// arithmetic and to the same value.
556 static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
557  const APInt &cstA,
558  const APInt &cstB, unsigned width,
559  IndexCmpPredicate pred) {
561  .Case([&](MinSOp op) {
562  return ConstantIntRanges::fromSigned(
563  APInt::getSignedMinValue(width), cstA);
564  })
565  .Case([&](MinUOp op) {
566  return ConstantIntRanges::fromUnsigned(
567  APInt::getMinValue(width), cstA);
568  })
569  .Case([&](MaxSOp op) {
570  return ConstantIntRanges::fromSigned(
571  cstA, APInt::getSignedMaxValue(width));
572  })
573  .Case([&](MaxUOp op) {
574  return ConstantIntRanges::fromUnsigned(
575  cstA, APInt::getMaxValue(width));
576  });
577  return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred),
578  lhsRange, ConstantIntRanges::constant(cstB));
579 }
580 
581 /// Return the result of `cmp(pred, x, x)`
582 static bool compareSameArgs(IndexCmpPredicate pred) {
583  switch (pred) {
584  case IndexCmpPredicate::EQ:
585  case IndexCmpPredicate::SGE:
586  case IndexCmpPredicate::SLE:
587  case IndexCmpPredicate::UGE:
588  case IndexCmpPredicate::ULE:
589  return true;
590  case IndexCmpPredicate::NE:
591  case IndexCmpPredicate::SGT:
592  case IndexCmpPredicate::SLT:
593  case IndexCmpPredicate::UGT:
594  case IndexCmpPredicate::ULT:
595  return false;
596  }
597 }
598 
599 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
600  // Attempt to fold if both inputs are constant.
601  auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
602  auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
603  if (lhs && rhs) {
604  // Perform the comparison in 64-bit and 32-bit.
605  bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
606  bool result32 = compareIndices(lhs.getValue().trunc(32),
607  rhs.getValue().trunc(32), getPred());
608  if (result64 == result32)
609  return BoolAttr::get(getContext(), result64);
610  }
611 
612  // Fold `cmp(max/min(x, cstA), cstB)`.
613  Operation *lhsOp = getLhs().getDefiningOp();
614  IntegerAttr cstA;
615  if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
616  matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
617  std::optional<bool> result64 = foldCmpOfMaxOrMin(
618  lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
619  std::optional<bool> result32 =
620  foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
621  rhs.getValue().trunc(32), 32, getPred());
622  // Fold if the 32-bit and 64-bit results are the same.
623  if (result64 && result32 && *result64 == *result32)
624  return BoolAttr::get(getContext(), *result64);
625  }
626 
627  // Fold `cmp(x, x)`
628  if (getLhs() == getRhs())
629  return BoolAttr::get(getContext(), compareSameArgs(getPred()));
630 
631  return {};
632 }
633 
634 /// Canonicalize
635 /// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
636 /// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
637 LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
638  IntegerAttr cmpRhs;
639  IntegerAttr cmpLhs;
640 
641  bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
642  cmpRhs.getValue().isZero();
643  bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
644  cmpLhs.getValue().isZero();
645  if (!rhsIsZero && !lhsIsZero)
646  return rewriter.notifyMatchFailure(op.getLoc(),
647  "cmp is not comparing something with 0");
648  SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
649  : op.getRhs().getDefiningOp<index::SubOp>();
650  if (!subOp)
651  return rewriter.notifyMatchFailure(
652  op.getLoc(), "non-zero operand is not a result of subtraction");
653 
654  index::CmpOp newCmp;
655  if (rhsIsZero)
656  newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
657  subOp.getLhs(), subOp.getRhs());
658  else
659  newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
660  subOp.getRhs(), subOp.getLhs());
661  rewriter.replaceOp(op, newCmp);
662  return success();
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // ConstantOp
667 //===----------------------------------------------------------------------===//
668 
669 void ConstantOp::getAsmResultNames(
670  function_ref<void(Value, StringRef)> setNameFn) {
671  SmallString<32> specialNameBuffer;
672  llvm::raw_svector_ostream specialName(specialNameBuffer);
673  specialName << "idx" << getValueAttr().getValue();
674  setNameFn(getResult(), specialName.str());
675 }
676 
677 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
678 
679 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
680  build(b, state, b.getIndexType(), b.getIndexAttr(value));
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // BoolConstantOp
685 //===----------------------------------------------------------------------===//
686 
687 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
688  return getValueAttr();
689 }
690 
691 void BoolConstantOp::getAsmResultNames(
692  function_ref<void(Value, StringRef)> setNameFn) {
693  setNameFn(getResult(), getValue() ? "true" : "false");
694 }
695 
696 //===----------------------------------------------------------------------===//
697 // ODS-Generated Definitions
698 //===----------------------------------------------------------------------===//
699 
700 #define GET_OP_CLASSES
701 #include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
Definition: IndexOps.cpp:69
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
Definition: IndexOps.cpp:525
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result, for operations that do not satisfy `trunc(f(a, b)) = f(trunc(a), trunc(b))`.
Definition: IndexOps.cpp:95
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
`cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the values of `cstA` and `cstB`, the max or min operation, and the comparison predicate.
Definition: IndexOps.cpp:556
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
Definition: IndexOps.cpp:448
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ? -1 : 1 and then n*m > 0 ? (n+x)/m + 1 : -(-n/m).
Definition: IndexOps.cpp:217
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
Definition: IndexOps.cpp:582
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ? 1 : -1 and then n*m < 0 ? -1 - (x-n)/m : n/m.
Definition: IndexOps.cpp:268
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
IndexType getIndexType()
Definition: Builders.cpp:71
A set of arbitrary-precision integers representing bounds on a given integer value.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:67
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.