MLIR 23.0.0git
InferIntDivisibilityOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- InferIntDivisibilityOpInterfaceImpl.cpp ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Direct implementations of `InferIntDivisibilityOpInterface` for arith ops.
10//
11//===----------------------------------------------------------------------===//
12
14#include "mlir/IR/Matchers.h"
16
17#include <cstdlib>
18
19using namespace mlir;
20using namespace mlir::arith;
21
24 if (!divisibility.isUninitialized())
25 return divisibility.getValue();
26 APInt intVal;
27 if (matchPattern(v, m_ConstantInt(&intVal))) {
28 uint64_t udiv = intVal.getZExtValue();
29 uint64_t sdiv = std::abs(intVal.getSExtValue());
30 return ConstantIntDivisibility(udiv, sdiv);
31 }
32 return ConstantIntDivisibility(1, 1);
33}
34
35// Result divisibility is the GCD (union) of the operand divisibilities.
36template <typename OpTy>
37static void
39 SetIntDivisibilityFn setResultDivs) {
40 auto lhsDiv = getDivisibilityOfOperand(op.getLhs(), argDivs[0]);
41 auto rhsDiv = getDivisibilityOfOperand(op.getRhs(), argDivs[1]);
42 setResultDivs(op.getResult(), lhsDiv.getUnion(rhsDiv));
43}
44
45void ConstantOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
46 SetIntDivisibilityFn setResultDivs) {
47 auto constAttr = dyn_cast_if_present<IntegerAttr>(getValue());
48 if (!constAttr)
49 return;
50 const APInt &value = constAttr.getValue();
51 uint64_t udiv = value.getZExtValue();
52 uint64_t sdiv = std::abs(value.getSExtValue());
53 setResultDivs(getResult(), ConstantIntDivisibility(udiv, sdiv));
54}
55
56void AddIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
57 SetIntDivisibilityFn setResultDivs) {
58 inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
59}
60
61void SubIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
62 SetIntDivisibilityFn setResultDivs) {
63 inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
64}
65
66void MinUIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
67 SetIntDivisibilityFn setResultDivs) {
68 inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
69}
70
71void MaxUIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
72 SetIntDivisibilityFn setResultDivs) {
73 inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
74}
75
76void MinSIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
77 SetIntDivisibilityFn setResultDivs) {
78 inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
79}
80
81void MaxSIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
82 SetIntDivisibilityFn setResultDivs) {
83 inferBinaryGCDResultDivisibility(*this, argDivs, setResultDivs);
84}
85
86void MulIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
87 SetIntDivisibilityFn setResultDivs) {
88 auto lhsDivisibility = getDivisibilityOfOperand(getLhs(), argDivs[0]);
89 auto rhsDivisibility = getDivisibilityOfOperand(getRhs(), argDivs[1]);
90
91 uint64_t mulUDiv = lhsDivisibility.udiv() * rhsDivisibility.udiv();
92 uint64_t mulSDiv = lhsDivisibility.sdiv() * rhsDivisibility.sdiv();
93
94 setResultDivs(getResult(), ConstantIntDivisibility(mulUDiv, mulSDiv));
95}
96
97void DivUIOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
98 SetIntDivisibilityFn setResultDivs) {
99 APInt intVal;
100 if (!matchPattern(getRhs(), m_ConstantInt(&intVal)))
101 return;
102
103 auto lhsDivisibility = getDivisibilityOfOperand(getLhs(), argDivs[0]);
104
105 uint64_t divUDiv = lhsDivisibility.udiv() % intVal.getZExtValue() == 0
106 ? lhsDivisibility.udiv() / intVal.getZExtValue()
107 : 1;
108 uint64_t divSDiv =
109 lhsDivisibility.sdiv() % std::abs(intVal.getSExtValue()) == 0
110 ? lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue())
111 : 1;
112
113 setResultDivs(getResult(), ConstantIntDivisibility(divUDiv, divSDiv));
114}
115
116void SelectOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
117 SetIntDivisibilityFn setResultDivs) {
118 // argDivs[0] is the condition (i1), argDivs[1] is true, argDivs[2] is false.
119 auto trueDiv = getDivisibilityOfOperand(getTrueValue(), argDivs[1]);
120 auto falseDiv = getDivisibilityOfOperand(getFalseValue(), argDivs[2]);
121 setResultDivs(getResult(), trueDiv.getUnion(falseDiv));
122}
static ConstantIntDivisibility getDivisibilityOfOperand(Value v, IntegerDivisibility divisibility)
static void inferBinaryGCDResultDivisibility(OpTy op, ArrayRef< IntegerDivisibility > argDivs, SetIntDivisibilityFn setResultDivs)
Statically known divisibility information for an integer SSA value.
This lattice value represents the integer divisibility of an SSA value.
const ConstantIntDivisibility & getValue() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
llvm::function_ref< void(Value, const ConstantIntDivisibility &)> SetIntDivisibilityFn
The type of the setResultDivs callback provided to ops implementing InferIntDivisibilityOpInterface.