MLIR 22.0.0git
IndexOps.cpp
Go to the documentation of this file.
1//===- IndexOps.cpp - Index operation definitions --------------------------==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/Matchers.h"
16#include "llvm/ADT/SmallString.h"
17#include "llvm/ADT/TypeSwitch.h"
18
19using namespace mlir;
20using namespace mlir::index;
21
22//===----------------------------------------------------------------------===//
23// IndexDialect
24//===----------------------------------------------------------------------===//
25
26void IndexDialect::registerOperations() {
27 addOperations<
28#define GET_OP_LIST
29#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
30 >();
31}
32
33Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
34 Type type, Location loc) {
35 // Materialize bool constants as `i1`.
36 if (auto boolValue = dyn_cast<BoolAttr>(value)) {
37 if (!type.isSignlessInteger(1))
38 return nullptr;
39 return BoolConstantOp::create(b, loc, type, boolValue);
40 }
41
42 // Materialize integer attributes as `index`.
43 if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
44 if (!llvm::isa<IndexType>(indexValue.getType()) ||
45 !llvm::isa<IndexType>(type))
46 return nullptr;
47 assert(indexValue.getValue().getBitWidth() ==
48 IndexType::kInternalStorageBitWidth);
49 return ConstantOp::create(b, loc, indexValue);
50 }
51
52 return nullptr;
53}
54
55//===----------------------------------------------------------------------===//
56// Fold Utilities
57//===----------------------------------------------------------------------===//
58
59/// Fold an index operation irrespective of the target bitwidth. The
60/// operation must satisfy the property:
61///
62/// ```
63/// trunc(f(a, b)) = f(trunc(a), trunc(b))
64/// ```
65///
66/// For all values of `a` and `b`. The function accepts a lambda that computes
67/// the integer result, which in turn must satisfy the above property.
69 ArrayRef<Attribute> operands,
70 function_ref<std::optional<APInt>(const APInt &, const APInt &)>
71 calculate) {
72 assert(operands.size() == 2 && "binary operation expected 2 operands");
73 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
74 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
75 if (!lhs || !rhs)
76 return {};
77
78 std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
79 if (!result)
80 return {};
81 assert(result->trunc(32) ==
82 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
83 return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
84}
85
86/// Fold an index operation only if the truncated 64-bit result matches the
87/// 32-bit result for operations that don't satisfy the above property. These
88/// are operations where the upper bits of the operands can affect the lower
89/// bits of the results.
90///
91/// The function accepts a lambda that computes the integer result in both
92/// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
93/// not folded.
95 ArrayRef<Attribute> operands,
96 function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
97 calculate) {
98 assert(operands.size() == 2 && "binary operation expected 2 operands");
99 auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
100 auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
101 // Only fold index operands.
102 if (!lhs || !rhs)
103 return {};
104
105 // Compute the 64-bit result and the 32-bit result.
106 std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
107 if (!result64)
108 return {};
109 std::optional<APInt> result32 =
110 calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
111 if (!result32)
112 return {};
113 // Compare the truncated 64-bit result to the 32-bit result.
114 if (result64->trunc(32) != *result32)
115 return {};
116 // The operation can be folded for these particular operands.
117 return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
118}
119
120/// Helper for associative and commutative binary ops that can be transformed:
121/// `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
122/// where c1 and c2 are constants. It is expected that `tmp` will be folded.
123template <typename BinaryOp>
124LogicalResult
126 PatternRewriter &rewriter) {
127 if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant()))
128 return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
129
130 auto lhsOp = op.getLhs().template getDefiningOp<BinaryOp>();
131 if (!lhsOp)
132 return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not the same BinaryOp");
133
134 if (!mlir::matchPattern(lhsOp.getRhs(), mlir::m_Constant()))
135 return rewriter.notifyMatchFailure(op.getLoc(), "RHS of LHS op is not a constant");
136
137 Value c = rewriter.createOrFold<BinaryOp>(op->getLoc(), op.getRhs(),
138 lhsOp.getRhs());
139 if (c.getDefiningOp<BinaryOp>())
140 return rewriter.notifyMatchFailure(op.getLoc(), "new BinaryOp was not folded");
141
142 rewriter.replaceOpWithNewOp<BinaryOp>(op, lhsOp.getLhs(), c);
143 return success();
144}
145
146//===----------------------------------------------------------------------===//
147// AddOp
148//===----------------------------------------------------------------------===//
149
150OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
152 adaptor.getOperands(),
153 [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
154 return result;
155
156 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
157 // Fold `add(x, 0) -> x`.
158 if (rhs.getValue().isZero())
159 return getLhs();
160 }
161
162 return {};
163}
164
165LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
167}
168
169//===----------------------------------------------------------------------===//
170// SubOp
171//===----------------------------------------------------------------------===//
172
173OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
175 adaptor.getOperands(),
176 [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
177 return result;
178
179 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
180 // Fold `sub(x, 0) -> x`.
181 if (rhs.getValue().isZero())
182 return getLhs();
183 }
184
185 return {};
186}
187
188//===----------------------------------------------------------------------===//
189// MulOp
190//===----------------------------------------------------------------------===//
191
192OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
194 adaptor.getOperands(),
195 [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
196 return result;
197
198 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
199 // Fold `mul(x, 1) -> x`.
200 if (rhs.getValue().isOne())
201 return getLhs();
202 // Fold `mul(x, 0) -> 0`.
203 if (rhs.getValue().isZero())
204 return rhs;
205 }
206
207 return {};
208}
209
210LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
212}
213
214//===----------------------------------------------------------------------===//
215// DivSOp
216//===----------------------------------------------------------------------===//
217
218OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
219 return foldBinaryOpChecked(
220 adaptor.getOperands(),
221 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
222 // Don't fold division by zero.
223 if (rhs.isZero())
224 return std::nullopt;
225 return lhs.sdiv(rhs);
226 });
227}
228
229//===----------------------------------------------------------------------===//
230// DivUOp
231//===----------------------------------------------------------------------===//
232
233OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
234 return foldBinaryOpChecked(
235 adaptor.getOperands(),
236 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
237 // Don't fold division by zero.
238 if (rhs.isZero())
239 return std::nullopt;
240 return lhs.udiv(rhs);
241 });
242}
243
244//===----------------------------------------------------------------------===//
245// CeilDivSOp
246//===----------------------------------------------------------------------===//
247
248/// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
249/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
250static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
251 // Don't fold division by zero.
252 if (m.isZero())
253 return std::nullopt;
254 // Short-circuit the zero case.
255 if (n.isZero())
256 return n;
257
258 bool mGtZ = m.sgt(0);
259 if (n.sgt(0) != mGtZ) {
260 // If the operands have different signs, compute the negative result. Signed
261 // division overflow is not possible, since if `m == -1`, `n` can be at most
262 // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
263 return -(-n).sdiv(m);
264 }
265 // Otherwise, compute the positive result. Signed division overflow is not
266 // possible since if `m == -1`, `x` will be `1`.
267 int64_t x = mGtZ ? -1 : 1;
268 return (n + x).sdiv(m) + 1;
269}
270
271OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
272 return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
273}
274
275//===----------------------------------------------------------------------===//
276// CeilDivUOp
277//===----------------------------------------------------------------------===//
278
279OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
280 // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
281 return foldBinaryOpChecked(
282 adaptor.getOperands(),
283 [](const APInt &n, const APInt &m) -> std::optional<APInt> {
284 // Don't fold division by zero.
285 if (m.isZero())
286 return std::nullopt;
287 // Short-circuit the zero case.
288 if (n.isZero())
289 return n;
290
291 return (n - 1).udiv(m) + 1;
292 });
293}
294
295//===----------------------------------------------------------------------===//
296// FloorDivSOp
297//===----------------------------------------------------------------------===//
298
299/// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
300/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
301static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
302 // Don't fold division by zero.
303 if (m.isZero())
304 return std::nullopt;
305 // Short-circuit the zero case.
306 if (n.isZero())
307 return n;
308
309 bool mLtZ = m.slt(0);
310 if (n.slt(0) == mLtZ) {
311 // If the operands have the same sign, compute the positive result.
312 return n.sdiv(m);
313 }
314 // If the operands have different signs, compute the negative result. Signed
315 // division overflow is not possible since if `m == -1`, `x` will be 1 and
316 // `n` can be at most `INT_MAX`.
317 int64_t x = mLtZ ? 1 : -1;
318 return -1 - (x - n).sdiv(m);
319}
320
321OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
322 return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
323}
324
325//===----------------------------------------------------------------------===//
326// RemSOp
327//===----------------------------------------------------------------------===//
328
329OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
330 return foldBinaryOpChecked(
331 adaptor.getOperands(),
332 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
333 // Don't fold division by zero.
334 if (rhs.isZero())
335 return std::nullopt;
336 return lhs.srem(rhs);
337 });
338}
339
340//===----------------------------------------------------------------------===//
341// RemUOp
342//===----------------------------------------------------------------------===//
343
344OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
345 return foldBinaryOpChecked(
346 adaptor.getOperands(),
347 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
348 // Don't fold division by zero.
349 if (rhs.isZero())
350 return std::nullopt;
351 return lhs.urem(rhs);
352 });
353}
354
355//===----------------------------------------------------------------------===//
356// MaxSOp
357//===----------------------------------------------------------------------===//
358
359OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
360 return foldBinaryOpChecked(adaptor.getOperands(),
361 [](const APInt &lhs, const APInt &rhs) {
362 return lhs.sgt(rhs) ? lhs : rhs;
363 });
364}
365
366LogicalResult MaxSOp::canonicalize(MaxSOp op, PatternRewriter &rewriter) {
368}
369
370//===----------------------------------------------------------------------===//
371// MaxUOp
372//===----------------------------------------------------------------------===//
373
374OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
375 return foldBinaryOpChecked(adaptor.getOperands(),
376 [](const APInt &lhs, const APInt &rhs) {
377 return lhs.ugt(rhs) ? lhs : rhs;
378 });
379}
380
381LogicalResult MaxUOp::canonicalize(MaxUOp op, PatternRewriter &rewriter) {
383}
384
385//===----------------------------------------------------------------------===//
386// MinSOp
387//===----------------------------------------------------------------------===//
388
389OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
390 return foldBinaryOpChecked(adaptor.getOperands(),
391 [](const APInt &lhs, const APInt &rhs) {
392 return lhs.slt(rhs) ? lhs : rhs;
393 });
394}
395
396LogicalResult MinSOp::canonicalize(MinSOp op, PatternRewriter &rewriter) {
398}
399
400//===----------------------------------------------------------------------===//
401// MinUOp
402//===----------------------------------------------------------------------===//
403
404OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
405 return foldBinaryOpChecked(adaptor.getOperands(),
406 [](const APInt &lhs, const APInt &rhs) {
407 return lhs.ult(rhs) ? lhs : rhs;
408 });
409}
410
411LogicalResult MinUOp::canonicalize(MinUOp op, PatternRewriter &rewriter) {
413}
414
415//===----------------------------------------------------------------------===//
416// ShlOp
417//===----------------------------------------------------------------------===//
418
419OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
421 adaptor.getOperands(),
422 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
423 // We cannot fold if the RHS is greater than or equal to 32 because
424 // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
425 // already treated as unsigned.
426 if (rhs.uge(32))
427 return {};
428 return lhs << rhs;
429 });
430}
431
432//===----------------------------------------------------------------------===//
433// ShrSOp
434//===----------------------------------------------------------------------===//
435
436OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
437 return foldBinaryOpChecked(
438 adaptor.getOperands(),
439 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
440 // Don't fold if RHS is greater than or equal to 32.
441 if (rhs.uge(32))
442 return {};
443 return lhs.ashr(rhs);
444 });
445}
446
447//===----------------------------------------------------------------------===//
448// ShrUOp
449//===----------------------------------------------------------------------===//
450
451OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
452 return foldBinaryOpChecked(
453 adaptor.getOperands(),
454 [](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
455 // Don't fold if RHS is greater than or equal to 32.
456 if (rhs.uge(32))
457 return {};
458 return lhs.lshr(rhs);
459 });
460}
461
462//===----------------------------------------------------------------------===//
463// AndOp
464//===----------------------------------------------------------------------===//
465
466OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
468 adaptor.getOperands(),
469 [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
470}
471
472LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
474}
475
476//===----------------------------------------------------------------------===//
477// OrOp
478//===----------------------------------------------------------------------===//
479
480OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
482 adaptor.getOperands(),
483 [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
484}
485
486LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
488}
489
490//===----------------------------------------------------------------------===//
491// XOrOp
492//===----------------------------------------------------------------------===//
493
494OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
496 adaptor.getOperands(),
497 [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
498}
499
500LogicalResult XOrOp::canonicalize(XOrOp op, PatternRewriter &rewriter) {
502}
503
504//===----------------------------------------------------------------------===//
505// CastSOp
506//===----------------------------------------------------------------------===//
507
508static OpFoldResult
510 function_ref<APInt(const APInt &, unsigned)> extFn,
511 function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
512 auto attr = dyn_cast_if_present<IntegerAttr>(input);
513 if (!attr)
514 return {};
515 const APInt &value = attr.getValue();
516
517 if (isa<IndexType>(type)) {
518 // When casting to an index type, perform the cast assuming a 64-bit target.
519 // The result can be truncated to 32 bits as needed and always be correct.
520 // This is because `cast32(cast64(value)) == cast32(value)`.
521 APInt result = extOrTruncFn(value, 64);
522 return IntegerAttr::get(type, result);
523 }
524
525 // When casting from an index type, we must ensure the results respect
526 // `cast_t(value) == cast_t(trunc32(value))`.
527 auto intType = cast<IntegerType>(type);
528 unsigned width = intType.getWidth();
529
530 // If the result type is at most 32 bits, then the cast can always be folded
531 // because it is always a truncation.
532 if (width <= 32) {
533 APInt result = value.trunc(width);
534 return IntegerAttr::get(type, result);
535 }
536
537 // If the result type is at least 64 bits, then the cast is always a
538 // extension. The results will differ if `trunc32(value) != value)`.
539 if (width >= 64) {
540 if (extFn(value.trunc(32), 64) != value)
541 return {};
542 APInt result = extFn(value, width);
543 return IntegerAttr::get(type, result);
544 }
545
546 // Otherwise, we just have to check the property directly.
547 APInt result = value.trunc(width);
548 if (result != extFn(value.trunc(32), width))
549 return {};
550 return IntegerAttr::get(type, result);
551}
552
553bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
554 return llvm::isa<IndexType>(lhsTypes.front()) !=
555 llvm::isa<IndexType>(rhsTypes.front());
556}
557
558OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
559 return foldCastOp(
560 adaptor.getInput(), getType(),
561 [](const APInt &x, unsigned width) { return x.sext(width); },
562 [](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
563}
564
565//===----------------------------------------------------------------------===//
566// CastUOp
567//===----------------------------------------------------------------------===//
568
569bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
570 return llvm::isa<IndexType>(lhsTypes.front()) !=
571 llvm::isa<IndexType>(rhsTypes.front());
572}
573
574OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
575 return foldCastOp(
576 adaptor.getInput(), getType(),
577 [](const APInt &x, unsigned width) { return x.zext(width); },
578 [](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
579}
580
581//===----------------------------------------------------------------------===//
582// CmpOp
583//===----------------------------------------------------------------------===//
584
585/// Compare two integers according to the comparison predicate.
586bool compareIndices(const APInt &lhs, const APInt &rhs,
587 IndexCmpPredicate pred) {
588 switch (pred) {
589 case IndexCmpPredicate::EQ:
590 return lhs.eq(rhs);
591 case IndexCmpPredicate::NE:
592 return lhs.ne(rhs);
593 case IndexCmpPredicate::SGE:
594 return lhs.sge(rhs);
595 case IndexCmpPredicate::SGT:
596 return lhs.sgt(rhs);
597 case IndexCmpPredicate::SLE:
598 return lhs.sle(rhs);
599 case IndexCmpPredicate::SLT:
600 return lhs.slt(rhs);
601 case IndexCmpPredicate::UGE:
602 return lhs.uge(rhs);
603 case IndexCmpPredicate::UGT:
604 return lhs.ugt(rhs);
605 case IndexCmpPredicate::ULE:
606 return lhs.ule(rhs);
607 case IndexCmpPredicate::ULT:
608 return lhs.ult(rhs);
609 }
610 llvm_unreachable("unhandled IndexCmpPredicate predicate");
611}
612
613/// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
614/// values of `cstA` and `cstB`, the max or min operation, and the comparison
615/// predicate. Check whether the value folds in both 32-bit and 64-bit
616/// arithmetic and to the same value.
617static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
618 const APInt &cstA,
619 const APInt &cstB, unsigned width,
620 IndexCmpPredicate pred) {
622 .Case([&](MinSOp op) {
624 APInt::getSignedMinValue(width), cstA);
625 })
626 .Case([&](MinUOp op) {
628 APInt::getMinValue(width), cstA);
629 })
630 .Case([&](MaxSOp op) {
632 cstA, APInt::getSignedMaxValue(width));
633 })
634 .Case([&](MaxUOp op) {
636 cstA, APInt::getMaxValue(width));
637 });
638 return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred),
639 lhsRange, ConstantIntRanges::constant(cstB));
640}
641
642/// Return the result of `cmp(pred, x, x)`
643static bool compareSameArgs(IndexCmpPredicate pred) {
644 switch (pred) {
645 case IndexCmpPredicate::EQ:
646 case IndexCmpPredicate::SGE:
647 case IndexCmpPredicate::SLE:
648 case IndexCmpPredicate::UGE:
649 case IndexCmpPredicate::ULE:
650 return true;
651 case IndexCmpPredicate::NE:
652 case IndexCmpPredicate::SGT:
653 case IndexCmpPredicate::SLT:
654 case IndexCmpPredicate::UGT:
655 case IndexCmpPredicate::ULT:
656 return false;
657 }
658 llvm_unreachable("unknown predicate in compareSameArgs");
659}
660
661OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
662 // Attempt to fold if both inputs are constant.
663 auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
664 auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
665 if (lhs && rhs) {
666 // Perform the comparison in 64-bit and 32-bit.
667 bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
668 bool result32 = compareIndices(lhs.getValue().trunc(32),
669 rhs.getValue().trunc(32), getPred());
670 if (result64 == result32)
671 return BoolAttr::get(getContext(), result64);
672 }
673
674 // Fold `cmp(max/min(x, cstA), cstB)`.
675 Operation *lhsOp = getLhs().getDefiningOp();
676 IntegerAttr cstA;
677 if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
678 matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
679 std::optional<bool> result64 = foldCmpOfMaxOrMin(
680 lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
681 std::optional<bool> result32 =
682 foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
683 rhs.getValue().trunc(32), 32, getPred());
684 // Fold if the 32-bit and 64-bit results are the same.
685 if (result64 && result32 && *result64 == *result32)
686 return BoolAttr::get(getContext(), *result64);
687 }
688
689 // Fold `cmp(x, x)`
690 if (getLhs() == getRhs())
691 return BoolAttr::get(getContext(), compareSameArgs(getPred()));
692
693 return {};
694}
695
696/// Canonicalize
697/// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
698/// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
699LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
700 IntegerAttr cmpRhs;
701 IntegerAttr cmpLhs;
702
703 bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
704 cmpRhs.getValue().isZero();
705 bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
706 cmpLhs.getValue().isZero();
707 if (!rhsIsZero && !lhsIsZero)
708 return rewriter.notifyMatchFailure(op.getLoc(),
709 "cmp is not comparing something with 0");
710 SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
711 : op.getRhs().getDefiningOp<index::SubOp>();
712 if (!subOp)
713 return rewriter.notifyMatchFailure(
714 op.getLoc(), "non-zero operand is not a result of subtraction");
715
716 index::CmpOp newCmp;
717 if (rhsIsZero)
718 newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
719 subOp.getLhs(), subOp.getRhs());
720 else
721 newCmp = index::CmpOp::create(rewriter, op.getLoc(), op.getPred(),
722 subOp.getRhs(), subOp.getLhs());
723 rewriter.replaceOp(op, newCmp);
724 return success();
725}
726
727//===----------------------------------------------------------------------===//
728// ConstantOp
729//===----------------------------------------------------------------------===//
730
731void ConstantOp::getAsmResultNames(
732 function_ref<void(Value, StringRef)> setNameFn) {
733 SmallString<32> specialNameBuffer;
734 llvm::raw_svector_ostream specialName(specialNameBuffer);
735 specialName << "idx" << getValueAttr().getValue();
736 setNameFn(getResult(), specialName.str());
737}
738
739OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
740
741void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
742 build(b, state, b.getIndexType(), b.getIndexAttr(value));
743}
744
745//===----------------------------------------------------------------------===//
746// BoolConstantOp
747//===----------------------------------------------------------------------===//
748
749OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
750 return getValueAttr();
751}
752
753void BoolConstantOp::getAsmResultNames(
754 function_ref<void(Value, StringRef)> setNameFn) {
755 setNameFn(getResult(), getValue() ? "true" : "false");
756}
757
758//===----------------------------------------------------------------------===//
759// ODS-Generated Definitions
760//===----------------------------------------------------------------------===//
761
762#define GET_OP_CLASSES
763#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
return success()
lhs
static OpFoldResult foldBinaryOpUnchecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &)> calculate)
Fold an index operation irrespective of the target bitwidth.
Definition IndexOps.cpp:68
static std::optional< APInt > calculateFloorDivS(const APInt &n, const APInt &m)
Compute floordivs(n, m) as x = m < 0 ?
Definition IndexOps.cpp:301
static std::optional< APInt > calculateCeilDivS(const APInt &n, const APInt &m)
Compute ceildivs(n, m) as x = m > 0 ?
Definition IndexOps.cpp:250
static std::optional< bool > foldCmpOfMaxOrMin(Operation *lhsOp, const APInt &cstA, const APInt &cstB, unsigned width, IndexCmpPredicate pred)
cmp(max/min(x, cstA), cstB) can be folded to a constant depending on the values of cstA and cstB,...
Definition IndexOps.cpp:617
LogicalResult canonicalizeAssociativeCommutativeBinaryOp(BinaryOp op, PatternRewriter &rewriter)
Helper for associative and commutative binary ops that can be transformed: x = op(v,...
Definition IndexOps.cpp:125
bool compareIndices(const APInt &lhs, const APInt &rhs, IndexCmpPredicate pred)
Compare two integers according to the comparison predicate.
Definition IndexOps.cpp:586
static OpFoldResult foldBinaryOpChecked(ArrayRef< Attribute > operands, function_ref< std::optional< APInt >(const APInt &, const APInt &lhs)> calculate)
Fold an index operation only if the truncated 64-bit result matches the 32-bit result for operations ...
Definition IndexOps.cpp:94
static OpFoldResult foldCastOp(Attribute input, Type type, function_ref< APInt(const APInt &, unsigned)> extFn, function_ref< APInt(const APInt &, unsigned)> extOrTruncFn)
Definition IndexOps.cpp:509
static bool compareSameArgs(IndexCmpPredicate pred)
Return the result of cmp(pred, x, x)
Definition IndexOps.cpp:643
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
A set of arbitrary-precision integers representing bounds on a given integer value.
static ConstantIntRanges constant(const APInt &value)
Create a ConstantIntRanges with a constant value - that is, with the bounds [value,...
static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax)
Create an ConstantIntRanges with the unsigned minimum and maximum equal to umin and umax and the sign...
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax)
Create an ConstantIntRanges with the signed minimum and maximum equal to smin and smax,...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:64
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
std::optional< bool > evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns a boolean value if pred is statically true or false for any possible inputs falling within lhs...
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This represents an operation in an abstracted form, suitable for use with the builder APIs.