MLIR 22.0.0git
PWMAFunction.cpp
Go to the documentation of this file.
1//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/ADT/STLFunctionalExtras.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/Support/raw_ostream.h"
18#include <algorithm>
19#include <cassert>
20#include <optional>
21
22using namespace mlir;
23using namespace presburger;
24
25void MultiAffineFunction::assertIsConsistent() const {
26 assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
27 output.getNumColumns() &&
28 "Inconsistent number of output columns");
29 assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
30 divs.getNumNonDivs() &&
31 "Inconsistent number of non-division variables in divs");
32 assert(space.getNumRangeVars() == output.getNumRows() &&
33 "Inconsistent number of output rows");
34 assert(space.getNumLocalVars() == divs.getNumDivs() &&
35 "Inconsistent number of divisions.");
36 assert(divs.hasAllReprs() && "All divisions should have a representation");
37}
38
39// Return the result of subtracting the two given vectors pointwise.
40// The vectors must be of the same size.
41// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
44 assert(vecA.size() == vecB.size() &&
45 "Cannot subtract vectors of differing lengths!");
47 result.reserve(vecA.size());
48 for (unsigned i = 0, e = vecA.size(); i < e; ++i)
49 result.emplace_back(vecA[i] - vecB[i]);
50 return result;
51}
52
55 for (const Piece &piece : pieces)
56 domain.unionInPlace(piece.domain);
57 return domain;
58}
59
61 space.print(os);
62 os << "Division Representation:\n";
63 divs.print(os);
64 os << "Output:\n";
65 output.print(os);
66}
67
68void MultiAffineFunction::dump() const { print(llvm::errs()); }
69
72 assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
73 "Point has incorrect dimensionality!");
74
75 SmallVector<DynamicAPInt, 8> pointHomogenous{llvm::to_vector(point)};
76 // Get the division values at this point.
78 divs.divValuesAt(point);
79 // The given point didn't include the values of the divs which the output is a
80 // function of; we have computed one possible set of values and use them here.
81 pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
82 for (const std::optional<DynamicAPInt> &divVal : divValues)
83 pointHomogenous.emplace_back(*divVal);
84 // The matrix `output` has an affine expression in the ith row, corresponding
85 // to the expression for the ith value in the output vector. The last column
86 // of the matrix contains the constant term. Let v be the input point with
87 // a 1 appended at the end. We can see that output * v gives the desired
88 // output vector.
89 pointHomogenous.emplace_back(1);
91 output.postMultiplyWithColumn(pointHomogenous);
92 assert(result.size() == getNumOutputs());
93 return result;
94}
95
97 assert(space.isCompatible(other.space) &&
98 "Spaces should be compatible for equality check.");
99 return getAsRelation().isEqual(other.getAsRelation());
100}
101
103 const IntegerPolyhedron &domain) const {
104 assert(space.isCompatible(other.space) &&
105 "Spaces should be compatible for equality check.");
106 IntegerRelation restrictedThis = getAsRelation();
107 restrictedThis.intersectDomain(domain);
108
109 IntegerRelation restrictedOther = other.getAsRelation();
110 restrictedOther.intersectDomain(domain);
111
112 return restrictedThis.isEqual(restrictedOther);
113}
114
116 const PresburgerSet &domain) const {
117 assert(space.isCompatible(other.space) &&
118 "Spaces should be compatible for equality check.");
119 return llvm::all_of(domain.getAllDisjuncts(),
120 [&](const IntegerRelation &disjunct) {
121 return isEqual(other, IntegerPolyhedron(disjunct));
122 });
123}
124
125void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
126 assert(end <= getNumOutputs() && "Invalid range");
127
128 if (start >= end)
129 return;
130
131 space.removeVarRange(VarKind::Range, start, end);
132 output.removeRows(start, end - start);
133}
134
136 assert(space.isCompatible(other.space) && "Functions should be compatible");
137
138 unsigned nDivs = getNumDivs();
139 unsigned divOffset = divs.getDivOffset();
140
141 other.divs.insertDiv(0, nDivs);
142
144 for (unsigned i = 0; i < nDivs; ++i) {
145 // Zero fill.
146 llvm::fill(div, 0);
147 // Fill div with dividend from `divs`. Do not fill the constant.
148 std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
149 div.begin());
150 // Fill constant.
151 div.back() = divs.getDividend(i).back();
152 other.divs.setDiv(i, div, divs.getDenom(i));
153 }
154
155 other.space.insertVar(VarKind::Local, 0, nDivs);
156 other.output.insertColumns(divOffset, nDivs);
157
158 auto merge = [&](unsigned i, unsigned j) {
159 // We only merge from local at pos j to local at pos i, where j > i.
160 if (i >= j)
161 return false;
162
163 // If j < nDivs (and hence i < nDivs), we are trying to merge duplicate divs in `this`. Since we
164 // do not want to merge duplicates in `this`, we ignore this call.
165 if (j < nDivs)
166 return false;
167
168 // Merge things in space and output.
169 other.space.removeVarRange(VarKind::Local, j, j + 1);
170 other.output.addToColumn(divOffset + i, divOffset + j, 1);
171 other.output.removeColumn(divOffset + j);
172 return true;
173 };
174
175 other.divs.removeDuplicateDivs(merge);
176
177 unsigned newDivs = other.divs.getNumDivs() - nDivs;
178
179 space.insertVar(VarKind::Local, nDivs, newDivs);
180 output.insertColumns(divOffset + nDivs, newDivs);
181 divs = other.divs;
182
183 // Check consistency.
184 assertIsConsistent();
185 other.assertIsConsistent();
186}
187
190 const MultiAffineFunction &other) const {
191 assert(getSpace().isCompatible(other.getSpace()) &&
192 "Output space of funcs should be compatible");
193
194 // Create copies of functions and merge their local space.
195 MultiAffineFunction funcA = *this;
196 MultiAffineFunction funcB = other;
197 funcA.mergeDivs(funcB);
198
199 // We first create the set `result`, corresponding to the set where output
200 // of funcA is lexicographically larger/smaller than funcB. This is done by
201 // creating a PresburgerSet with the following constraints:
202 //
203 // (outA[0] > outB[0]) U
204 // (outA[0] = outB[0], outA[1] > outB[1]) U
205 // (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U
206 // ...
207 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
208 //
209 // where `n` is the number of outputs.
210 // If `comp` is `OrderingKind::LT`, the reversed inequality is used:
211 //
212 // (outA[0] < outB[0]) U
213 // (outA[0] = outB[0], outA[1] < outB[1]) U
214 // (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U
215 // ...
216 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
217 PresburgerSpace resultSpace = funcA.getDomainSpace();
220 IntegerPolyhedron levelSet(
221 /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
222 /*numReservedEqualities=*/funcA.getNumOutputs(),
223 /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
224
225 // Add division inequalities to `levelSet`.
226 for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
227 levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
228 funcA.divs.getDenom(i),
229 funcA.divs.getDivOffset() + i));
230 levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
231 funcA.divs.getDenom(i),
232 funcA.divs.getDivOffset() + i));
233 }
234
235 for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
236 // Create the expression `outA - outB` for this level.
238 subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));
239
240 // TODO: Implement all comparison cases.
241 switch (comp) {
242 case OrderingKind::LT:
243 // For less than, we add an upper bound of -1:
244 // outA - outB <= -1
245 // outA <= outB - 1
246 // outA < outB
247 levelSet.addBound(BoundType::UB, subExpr, DynamicAPInt(-1));
248 break;
249 case OrderingKind::GT:
250 // For greater than, we add a lower bound of 1:
251 // outA - outB >= 1
252 // outA >= outB + 1
253 // outA > outB
254 levelSet.addBound(BoundType::LB, subExpr, DynamicAPInt(1));
255 break;
256 case OrderingKind::GE:
257 case OrderingKind::LE:
258 case OrderingKind::EQ:
259 case OrderingKind::NE:
260 assert(false && "Not implemented case");
261 }
262
263 // Union the set with the result.
264 result.unionInPlace(levelSet);
265 // The last inequality in `levelSet` is the bound we inserted. We remove
266 // that for next iteration.
267 levelSet.removeInequality(levelSet.getNumInequalities() - 1);
268 // Add equality `outA - outB == 0` for this level for next iteration.
269 levelSet.addEquality(subExpr);
270 }
271
272 return result;
273}
274
275/// Two PWMAFunctions are equal if they have the same dimensionalities,
276/// the same domain, and take the same value at every point in the domain.
277bool PWMAFunction::isEqual(const PWMAFunction &other) const {
278 if (!space.isCompatible(other.space))
279 return false;
280
281 if (!this->getDomain().isEqual(other.getDomain()))
282 return false;
283
284 // Check if, whenever the domains of a piece of `this` and a piece of `other`
285 // overlap, they take the same output value. If `this` and `other` have the
286 // same domain (checked above), then this check passes iff the two functions
287 // have the same output at every point in the domain.
288 return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
289 return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
290 PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
291 return pieceA.output.isEqual(pieceB.output, commonDomain);
292 });
293 });
294}
295
296void PWMAFunction::addPiece(const Piece &piece) {
297 assert(piece.isConsistent() && "Piece should be consistent");
298 assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
299 "Piece should be disjoint from the function");
300 pieces.emplace_back(piece);
301}
302
304 space.print(os);
305 os << getNumPieces() << " pieces:\n";
306 for (const Piece &piece : pieces) {
307 os << "Domain of piece:\n";
308 piece.domain.print(os);
309 os << "Output of piece\n";
310 piece.output.print(os);
311 }
312}
313
314void PWMAFunction::dump() const { print(llvm::errs()); }
315
316PWMAFunction PWMAFunction::unionFunction(
317 const PWMAFunction &func,
318 llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
319 assert(getNumOutputs() == func.getNumOutputs() &&
320 "Ranges of functions should be same.");
321 assert(getSpace().isCompatible(func.getSpace()) &&
322 "Space is not compatible.");
323
324 // The algorithm used here is as follows:
325 // - Add the output of pieceB for the part of the domain where both pieceA and
326 // pieceB are defined, and `tiebreak` chooses the output of pieceB.
327 // - Add the output of pieceA, where pieceB is not defined or `tiebreak`
328 // chooses
329 // pieceA over pieceB.
330 // - Add the output of pieceB, where pieceA is not defined.
331
332 // Add parts of the common domain where pieceB's output is used. Also
333 // add all the parts where pieceA's output is used, both common and
334 // non-common.
336 for (const Piece &pieceA : pieces) {
337 PresburgerSet dom(pieceA.domain);
338 for (const Piece &pieceB : func.pieces) {
339 PresburgerSet better = tiebreak(pieceB, pieceA);
340 // Add the output of pieceB, where it is better than output of pieceA.
341 // The disjuncts in "better" will be disjoint as tiebreak should guarantee
342 // that.
343 result.addPiece({better, pieceB.output});
344 dom = dom.subtract(better);
345 }
346 // Add output of pieceA, where it is better than pieceB, or pieceB is not
347 // defined.
348 //
349 // `dom` here is guaranteed to be disjoint from already added pieces
350 // because the pieces added before are either:
351 // - Subsets of the domain of other MAFs in `this`, which are guaranteed
352 // to be disjoint from `dom`, or
353 // - They are one of the pieces added for `pieceB`, and we have been
354 // subtracting all such pieces from `dom`, so `dom` is disjoint from those
355 // pieces as well.
356 result.addPiece({dom, pieceA.output});
357 }
358
359 // Add parts of pieceB which are not shared with pieceA.
360 PresburgerSet dom = getDomain();
361 for (const Piece &pieceB : func.pieces)
362 result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
363
364 return result;
365}
366
367/// A tiebreak function which breaks ties by comparing the outputs
368/// lexicographically based on the given comparison operator.
369/// This is templated since it is passed as a lambda.
370template <OrderingKind comp>
372 const PWMAFunction::Piece &pieceB) {
373 PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output);
374 result = result.intersect(pieceA.domain).intersect(pieceB.domain);
375
376 return result;
377}
378
382
386
388 assert(space.isCompatible(other.space) &&
389 "Spaces should be compatible for subtraction.");
390
391 MultiAffineFunction copyOther = other;
392 mergeDivs(copyOther);
393 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
394 output.addToRow(i, copyOther.getOutputExpr(i), DynamicAPInt(-1));
395
396 // Check consistency.
397 assertIsConsistent();
398}
399
400/// Adds division constraints corresponding to local variables, given a
401/// relation and division representations of the local variables in the
402/// relation.
404 const DivisionRepr &divs) {
405 assert(divs.hasAllReprs() &&
406 "All divisions in divs should have a representation");
407 assert(rel.getNumVars() == divs.getNumVars() &&
408 "Relation and divs should have the same number of vars");
409 assert(rel.getNumLocalVars() == divs.getNumDivs() &&
410 "Relation and divs should have the same number of local vars");
411
412 for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
414 divs.getDivOffset() + i));
416 divs.getDivOffset() + i));
417 }
418}
419
421 // Create a relation corresponding to the input space plus the divisions
422 // used in outputs.
424 space.getNumDomainVars(), 0, space.getNumSymbolVars(),
425 space.getNumLocalVars()));
426 // Add division constraints corresponding to divisions used in outputs.
428 // The outputs are represented as range variables in the relation. We add
429 // range variables for the outputs.
430 result.insertVar(VarKind::Range, 0, getNumOutputs());
431
432 // Add equalities such that the i^th range variable is equal to the i^th
433 // output expression.
434 SmallVector<DynamicAPInt, 8> eq(result.getNumCols());
435 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
436 // TODO: Add functions to get VarKind offsets in output in MAF and use them
437 // here.
438 // The output expression does not contain range variables, while the
439 // equality does. So, we need to copy all variables and mark all range
440 // variables as 0 in the equality.
442 // Copy domain variables in `expr` to domain variables in `eq`.
443 std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
444 // Fill the range variables in `eq` as zero.
445 std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
446 eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
447 // Copy remaining variables in `expr` to the remaining variables in `eq`.
448 std::copy(expr.begin() + getNumDomainVars(), expr.end(),
449 eq.begin() + result.getVarKindEnd(VarKind::Range));
450
451 // Set the i^th range var to -1 in `eq` to equate the output expression to
452 // this range var.
453 eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
454 // Add the equality `rangeVar_i = output[i]`.
455 result.addEquality(eq);
456 }
457
458 return result;
459}
460
461void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
462 space.removeVarRange(VarKind::Range, start, end);
463 for (Piece &piece : pieces)
464 piece.output.removeOutputs(start, end);
465}
466
467std::optional<SmallVector<DynamicAPInt, 8>>
469 assert(point.size() == getNumDomainVars() + getNumSymbolVars());
470
471 for (const Piece &piece : pieces)
472 if (piece.domain.containsPoint(point))
473 return piece.output.valueAt(point);
474 return std::nullopt;
475}
static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, const PWMAFunction::Piece &pieceB)
A tiebreak function which breaks ties by comparing the outputs lexicographically based on the given c...
static SmallVector< DynamicAPInt, 8 > subtractExprs(ArrayRef< DynamicAPInt > vecA, ArrayRef< DynamicAPInt > vecB)
static void addDivisionConstraints(IntegerRelation &rel, const DivisionRepr &divs)
Adds division constraints corresponding to local variables, given a relation and division representat...
#define div(a, b)
Class storing division representation of local variables of a constraint system.
Definition Utils.h:117
void removeDuplicateDivs(llvm::function_ref< bool(unsigned i, unsigned j)> merge)
Removes duplicate divisions.
Definition Utils.cpp:440
unsigned getNumVars() const
Definition Utils.h:124
unsigned getDivOffset() const
Definition Utils.h:128
unsigned getNumDivs() const
Definition Utils.h:125
DynamicAPInt & getDenom(unsigned i)
Definition Utils.h:153
void insertDiv(unsigned pos, ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor)
Definition Utils.cpp:494
void setDiv(unsigned i, ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor)
Definition Utils.h:158
MutableArrayRef< DynamicAPInt > getDividend(unsigned i)
Definition Utils.h:139
An IntegerPolyhedron represents the set of points from a PresburgerSpace that satisfy a list of affin...
An IntegerRelation represents the set of points from a PresburgerSpace that satisfy a list of affine ...
void addBound(BoundType type, unsigned pos, const DynamicAPInt &value)
Adds a constant bound for the specified variable.
void intersectDomain(const IntegerPolyhedron &poly)
Intersect the given poly with the domain in-place.
bool isEqual(const IntegerRelation &other) const
Return whether this and other are equal.
void addEquality(ArrayRef< DynamicAPInt > eq)
Adds an equality from the coefficients specified in eq.
void addInequality(ArrayRef< DynamicAPInt > inEq)
Adds an inequality (>= 0) from the coefficients specified in inEq.
void removeColumn(unsigned pos)
Definition Matrix.cpp:194
void addToColumn(unsigned sourceColumn, unsigned targetColumn, const T &scale)
Add scale multiples of the source column to the target column.
Definition Matrix.cpp:319
void insertColumns(unsigned pos, unsigned count)
Insert columns having positions pos, pos + 1, ... pos + count - 1.
Definition Matrix.cpp:152
void subtract(const MultiAffineFunction &other)
void removeOutputs(unsigned start, unsigned end)
Remove the specified range of outputs.
const PresburgerSpace & getSpace() const
Get the space of this function.
void print(raw_ostream &os) const
MultiAffineFunction(const PresburgerSpace &space, const IntMatrix &output)
PresburgerSpace getDomainSpace() const
Get the domain/output space of the function.
PresburgerSet getLexSet(OrderingKind comp, const MultiAffineFunction &other) const
Return the set of domain points where the output of this and other are ordered lexicographically acco...
SmallVector< DynamicAPInt, 8 > valueAt(ArrayRef< DynamicAPInt > point) const
void mergeDivs(MultiAffineFunction &other)
Given a MAF other, merges division variables such that both functions have the union of the division ...
ArrayRef< DynamicAPInt > getOutputExpr(unsigned i) const
Get the i^th output expression.
IntegerRelation getAsRelation() const
Get this function as a relation.
bool isEqual(const MultiAffineFunction &other) const
Return whether this and other are equal when the domain is restricted to domain.
This class represents a piece-wise MultiAffineFunction.
void addPiece(const Piece &piece)
unsigned getNumDomainVars() const
void print(raw_ostream &os) const
PWMAFunction unionLexMax(const PWMAFunction &func)
void removeOutputs(unsigned start, unsigned end)
Remove the specified range of outputs.
const PresburgerSpace & getSpace() const
PWMAFunction unionLexMin(const PWMAFunction &func)
Return a function defined on the union of the domains of this and func, such that when only one of th...
PWMAFunction(const PresburgerSpace &space)
std::optional< SmallVector< DynamicAPInt, 8 > > valueAt(ArrayRef< DynamicAPInt > point) const
Return the output of the function at the given point.
PresburgerSet getDomain() const
Return the domain of this piece-wise MultiAffineFunction.
PresburgerSpace getDomainSpace() const
Get the domain/output space of the function.
unsigned getNumSymbolVars() const
bool isEqual(const PWMAFunction &other) const
Return whether this and other are equal as PWMAFunctions, i.e.
void unionInPlace(const IntegerRelation &disjunct)
Mutate this set, turning it into the union of this set and the given disjunct.
ArrayRef< IntegerRelation > getAllDisjuncts() const
Return a reference to the list of disjuncts.
PresburgerSet intersect(const PresburgerRelation &set) const
static PresburgerSet getEmpty(const PresburgerSpace &space)
Return an empty set of the specified type that contains no points.
PresburgerSpace is the space of all possible values of a tuple of integer-valued variables.
void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit)
Removes variables of the specified kind in the column range [varStart, varLimit).
PresburgerSpace getSpaceWithoutLocals() const
Get the space without local variables.
static PresburgerSpace getRelationSpace(unsigned numDomain=0, unsigned numRange=0, unsigned numSymbols=0, unsigned numLocals=0)
unsigned insertVar(VarKind kind, unsigned pos, unsigned num=1)
Insert num variables of the specified kind at position pos.
OrderingKind
Enum representing a binary comparison operator: equal, not equal, less than, less than or equal,...
SmallVector< DynamicAPInt, 8 > getDivUpperBound(ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor, unsigned localVarIdx)
If q is defined to be equal to expr floordiv d, this is equivalent to saying that q is an integer and q ...
Definition Utils.cpp:315
SmallVector< DynamicAPInt, 8 > getDivLowerBound(ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor, unsigned localVarIdx)
Definition Utils.cpp:327
Include the generated interface declarations.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.