MLIR 23.0.0git
PWMAFunction.cpp
Go to the documentation of this file.
1//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/ADT/STLFunctionalExtras.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/Support/raw_ostream.h"
18#include <algorithm>
19#include <cassert>
20#include <optional>
21
22using namespace mlir;
23using namespace presburger;
24
25void MultiAffineFunction::assertIsConsistent() const {
26 assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
27 output.getNumColumns() &&
28 "Inconsistent number of output columns");
29 assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
30 divs.getNumNonDivs() &&
31 "Inconsistent number of non-division variables in divs");
32 assert(space.getNumRangeVars() == output.getNumRows() &&
33 "Inconsistent number of output rows");
34 assert(space.getNumLocalVars() == divs.getNumDivs() &&
35 "Inconsistent number of divisions.");
36 assert(divs.hasAllReprs() && "All divisions should have a representation");
37}
38
39// Return the result of subtracting the two given vectors pointwise.
40// The vectors must be of the same size.
41// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
44 assert(vecA.size() == vecB.size() &&
45 "Cannot subtract vectors of differing lengths!");
47 result.reserve(vecA.size());
48 for (unsigned i = 0, e = vecA.size(); i < e; ++i)
49 result.emplace_back(vecA[i] - vecB[i]);
50 return result;
51}
52
55 for (const Piece &piece : pieces)
56 domain.unionInPlace(piece.domain);
57 return domain;
58}
59
61 space.print(os);
62 os << "Division Representation:\n";
63 divs.print(os);
64 os << "Output:\n";
65 output.print(os);
66}
67
68void MultiAffineFunction::dump() const { print(llvm::errs()); }
69
72 assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
73 "Point has incorrect dimensionality!");
74
75 SmallVector<DynamicAPInt, 8> pointHomogenous{llvm::to_vector(point)};
76 // Get the division values at this point.
78 divs.divValuesAt(point);
79 // The given point didn't include the values of the divs which the output is a
80 // function of; we have computed one possible set of values and use them here.
81 pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
82 for (const std::optional<DynamicAPInt> &divVal : divValues)
83 pointHomogenous.emplace_back(*divVal);
84 // The matrix `output` has an affine expression in the ith row, corresponding
85 // to the expression for the ith value in the output vector. The last column
86 // of the matrix contains the constant term. Let v be the input point with
87 // a 1 appended at the end. We can see that output * v gives the desired
88 // output vector.
89 pointHomogenous.emplace_back(1);
91 output.postMultiplyWithColumn(pointHomogenous);
92 assert(result.size() == getNumOutputs());
93 return result;
94}
95
97 assert(space.isCompatible(other.space) &&
98 "Spaces should be compatible for equality check.");
99 return getAsRelation().isEqual(other.getAsRelation());
100}
101
103 const IntegerPolyhedron &domain) const {
104 assert(space.isCompatible(other.space) &&
105 "Spaces should be compatible for equality check.");
106 IntegerRelation restrictedThis = getAsRelation();
107 restrictedThis.intersectDomain(domain);
108
109 IntegerRelation restrictedOther = other.getAsRelation();
110 restrictedOther.intersectDomain(domain);
111
112 return restrictedThis.isEqual(restrictedOther);
113}
114
116 const PresburgerSet &domain) const {
117 assert(space.isCompatible(other.space) &&
118 "Spaces should be compatible for equality check.");
119 return llvm::all_of(domain.getAllDisjuncts(),
120 [&](const IntegerRelation &disjunct) {
121 return isEqual(other, IntegerPolyhedron(disjunct));
122 });
123}
124
125void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
126 assert(end <= getNumOutputs() && "Invalid range");
127
128 if (start >= end)
129 return;
130
131 space.removeVarRange(VarKind::Range, start, end);
132 output.removeRows(start, end - start);
133}
134
136 assert(space.isCompatible(other.space) && "Functions should be compatible");
137
138 unsigned nDivs = getNumDivs();
139 unsigned divOffset = divs.getDivOffset();
140
141 other.divs.insertDiv(0, nDivs);
142
144 for (unsigned i = 0; i < nDivs; ++i) {
145 // Zero fill.
146 llvm::fill(div, 0);
147 // Fill div with dividend from `divs`. Do not fill the constant.
148 std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
149 div.begin());
150 // Fill constant.
151 div.back() = divs.getDividend(i).back();
152 other.divs.setDiv(i, div, divs.getDenom(i));
153 }
154
155 other.space.insertVar(VarKind::Local, 0, nDivs);
156 other.output.insertColumns(divOffset, nDivs);
157
158 auto merge = [&](unsigned i, unsigned j) {
159 // We only merge from local at pos j to local at pos i, where j > i.
160 if (i >= j)
161 return false;
162
    // If j < nDivs (and hence i < nDivs), we are trying to merge duplicate
    // divs in `this`. Since we do not want to merge duplicates in `this`, we
    // ignore this call.
165 if (j < nDivs)
166 return false;
167
168 // Merge things in space and output.
169 other.space.removeVarRange(VarKind::Local, j, j + 1);
170 other.output.addToColumn(divOffset + i, divOffset + j, 1);
171 other.output.removeColumn(divOffset + j);
172 return true;
173 };
174
175 other.divs.removeDuplicateDivs(merge);
176
177 unsigned newDivs = other.divs.getNumDivs() - nDivs;
178
179 space.insertVar(VarKind::Local, nDivs, newDivs);
180 output.insertColumns(divOffset + nDivs, newDivs);
181 divs = other.divs;
182
183 // Check consistency.
184 assertIsConsistent();
185 other.assertIsConsistent();
186}
187
190 const MultiAffineFunction &other) const {
191 assert(getSpace().isCompatible(other.getSpace()) &&
192 "Output space of funcs should be compatible");
193
194 // Create copies of functions and merge their local space.
195 MultiAffineFunction funcA = *this;
196 MultiAffineFunction funcB = other;
197 funcA.mergeDivs(funcB);
198
199 // We first create the set `result`, corresponding to the set where output
200 // of funcA is lexicographically larger/smaller than funcB. This is done by
201 // creating a PresburgerSet with the following constraints:
202 //
203 // (outA[0] > outB[0]) U
  // (outA[0] = outB[0], outA[1] > outB[1]) U
  // (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U
206 // ...
207 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
208 //
209 // where `n` is the number of outputs.
210 // If `lexMin` is set, the complement inequality is used:
211 //
212 // (outA[0] < outB[0]) U
  // (outA[0] = outB[0], outA[1] < outB[1]) U
  // (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U
215 // ...
216 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
217 PresburgerSpace resultSpace = funcA.getDomainSpace();
220 IntegerPolyhedron levelSet(
221 /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
222 /*numReservedEqualities=*/funcA.getNumOutputs(),
223 /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
224
225 // Add division inequalities to `levelSet`.
226 for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
227 levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
228 funcA.divs.getDenom(i),
229 funcA.divs.getDivOffset() + i));
230 levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
231 funcA.divs.getDenom(i),
232 funcA.divs.getDivOffset() + i));
233 }
234
235 for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
236 // Create the expression `outA - outB` for this level.
238 subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));
239
240 // TODO: Implement all comparison cases.
241 switch (comp) {
242 case OrderingKind::LT:
243 // For less than, we add an upper bound of -1:
244 // outA - outB <= -1
245 // outA <= outB - 1
246 // outA < outB
247 levelSet.addBound(BoundType::UB, subExpr, DynamicAPInt(-1));
248 break;
249 case OrderingKind::GT:
250 // For greater than, we add a lower bound of 1:
251 // outA - outB >= 1
252 // outA > outB + 1
253 // outA > outB
254 levelSet.addBound(BoundType::LB, subExpr, DynamicAPInt(1));
255 break;
256 case OrderingKind::GE:
257 case OrderingKind::LE:
258 case OrderingKind::EQ:
259 case OrderingKind::NE:
260 assert(false && "Not implemented case");
261 }
262
263 // Union the set with the result.
264 result.unionInPlace(levelSet);
265 // The last inequality in `levelSet` is the bound we inserted. We remove
266 // that for next iteration.
267 levelSet.removeInequality(levelSet.getNumInequalities() - 1);
268 // Add equality `outA - outB == 0` for this level for next iteration.
269 levelSet.addEquality(subExpr);
270 }
271
272 return result;
273}
274
275/// Two PWMAFunctions are equal if they have the same dimensionalities,
276/// the same domain, and take the same value at every point in the domain.
277bool PWMAFunction::isEqual(const PWMAFunction &other) const {
278 if (!space.isCompatible(other.space))
279 return false;
280
281 if (!this->getDomain().isEqual(other.getDomain()))
282 return false;
283
284 // Check if, whenever the domains of a piece of `this` and a piece of `other`
285 // overlap, they take the same output value. If `this` and `other` have the
286 // same domain (checked above), then this check passes iff the two functions
287 // have the same output at every point in the domain.
288 return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
289 return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
290 PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
291 return pieceA.output.isEqual(pieceB.output, commonDomain);
292 });
293 });
294}
295
296void PWMAFunction::addPiece(const Piece &piece) {
297 assert(piece.isConsistent() && "Piece should be consistent");
298 assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
299 "Piece should be disjoint from the function");
300 pieces.emplace_back(piece);
301}
302
304 space.print(os);
305 os << getNumPieces() << " pieces:\n";
306 for (const Piece &piece : pieces) {
307 os << "Domain of piece:\n";
308 piece.domain.print(os);
309 os << "Output of piece\n";
310 piece.output.print(os);
311 }
312}
313
314void PWMAFunction::dump() const { print(llvm::errs()); }
315
316PWMAFunction PWMAFunction::unionFunction(
317 const PWMAFunction &func,
318 llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
319 assert(getNumOutputs() == func.getNumOutputs() &&
320 "Ranges of functions should be same.");
321 assert(getSpace().isCompatible(func.getSpace()) &&
322 "Space is not compatible.");
323
324 // The algorithm used here is as follows:
325 // - Add the output of pieceB for the part of the domain where both pieceA and
326 // pieceB are defined, and `tiebreak` chooses the output of pieceB.
  // - Add the output of pieceA, where pieceB is not defined or `tiebreak`
  //   chooses pieceA over pieceB.
330 // - Add the output of pieceB, where pieceA is not defined.
331
332 // Add parts of the common domain where pieceB's output is used. Also
333 // add all the parts where pieceA's output is used, both common and
334 // non-common.
336 for (const Piece &pieceA : pieces) {
337 PresburgerSet dom(pieceA.domain);
338 for (const Piece &pieceB : func.pieces) {
339 PresburgerSet better = tiebreak(pieceB, pieceA);
340 if (better.isIntegerEmpty())
341 continue;
342
343 // Add the output of pieceB, where it is better than output of pieceA.
      // The disjuncts in "better" will be disjoint, as tiebreak should
      // guarantee that.
346 result.addPiece({better, pieceB.output});
347 dom = dom.subtract(better);
348 }
349 // Add output of pieceA, where it is better than pieceB, or pieceB is not
350 // defined.
351 //
    // `dom` here is guaranteed to be disjoint from already added pieces
353 // because the pieces added before are either:
    // - Subsets of the domain of other MAFs in `this`, which are guaranteed
355 // to be disjoint from `dom`, or
356 // - They are one of the pieces added for `pieceB`, and we have been
357 // subtracting all such pieces from `dom`, so `dom` is disjoint from those
358 // pieces as well.
359 result.addPiece({dom, pieceA.output});
360 }
361
362 // Add parts of pieceB which are not shared with pieceA.
363 PresburgerSet dom = getDomain();
364 for (const Piece &pieceB : func.pieces)
365 result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
366
367 return result;
368}
369
370/// A tiebreak function which breaks ties by comparing the outputs
371/// lexicographically based on the given comparison operator.
372/// This is templated since it is passed as a lambda.
373template <OrderingKind comp>
375 const PWMAFunction::Piece &pieceB) {
376 PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output);
377 result = result.intersect(pieceA.domain).intersect(pieceB.domain);
378
379 return result;
380}
381
385
389
391 assert(space.isCompatible(other.space) &&
392 "Spaces should be compatible for subtraction.");
393
394 MultiAffineFunction copyOther = other;
395 mergeDivs(copyOther);
396 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
397 output.addToRow(i, copyOther.getOutputExpr(i), DynamicAPInt(-1));
398
399 // Check consistency.
400 assertIsConsistent();
401}
402
403/// Adds division constraints corresponding to local variables, given a
404/// relation and division representations of the local variables in the
405/// relation.
407 const DivisionRepr &divs) {
408 assert(divs.hasAllReprs() &&
409 "All divisions in divs should have a representation");
410 assert(rel.getNumVars() == divs.getNumVars() &&
411 "Relation and divs should have the same number of vars");
412 assert(rel.getNumLocalVars() == divs.getNumDivs() &&
413 "Relation and divs should have the same number of local vars");
414
415 for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
417 divs.getDivOffset() + i));
419 divs.getDivOffset() + i));
420 }
421}
422
  // Create a relation corresponding to the input space plus the divisions
425 // used in outputs.
427 space.getNumDomainVars(), 0, space.getNumSymbolVars(),
428 space.getNumLocalVars()));
429 // Add division constraints corresponding to divisions used in outputs.
431 // The outputs are represented as range variables in the relation. We add
432 // range variables for the outputs.
433 result.insertVar(VarKind::Range, 0, getNumOutputs());
434
435 // Add equalities such that the i^th range variable is equal to the i^th
436 // output expression.
437 SmallVector<DynamicAPInt, 8> eq(result.getNumCols());
438 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
439 // TODO: Add functions to get VarKind offsets in output in MAF and use them
440 // here.
441 // The output expression does not contain range variables, while the
442 // equality does. So, we need to copy all variables and mark all range
443 // variables as 0 in the equality.
445 // Copy domain variables in `expr` to domain variables in `eq`.
446 std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
447 // Fill the range variables in `eq` as zero.
448 std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
449 eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
450 // Copy remaining variables in `expr` to the remaining variables in `eq`.
451 std::copy(expr.begin() + getNumDomainVars(), expr.end(),
452 eq.begin() + result.getVarKindEnd(VarKind::Range));
453
454 // Set the i^th range var to -1 in `eq` to equate the output expression to
455 // this range var.
456 eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
457 // Add the equality `rangeVar_i = output[i]`.
458 result.addEquality(eq);
459 }
460
461 return result;
462}
463
464void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
465 space.removeVarRange(VarKind::Range, start, end);
466 for (Piece &piece : pieces)
467 piece.output.removeOutputs(start, end);
468}
469
470std::optional<SmallVector<DynamicAPInt, 8>>
472 assert(point.size() == getNumDomainVars() + getNumSymbolVars());
473
474 for (const Piece &piece : pieces)
475 if (piece.domain.containsPoint(point))
476 return piece.output.valueAt(point);
477 return std::nullopt;
478}
static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, const PWMAFunction::Piece &pieceB)
A tiebreak function which breaks ties by comparing the outputs lexicographically based on the given c...
static SmallVector< DynamicAPInt, 8 > subtractExprs(ArrayRef< DynamicAPInt > vecA, ArrayRef< DynamicAPInt > vecB)
static void addDivisionConstraints(IntegerRelation &rel, const DivisionRepr &divs)
Adds division constraints corresponding to local variables, given a relation and division representat...
#define div(a, b)
Class storing division representation of local variables of a constraint system.
Definition Utils.h:117
void removeDuplicateDivs(llvm::function_ref< bool(unsigned i, unsigned j)> merge)
Removes duplicate divisions.
Definition Utils.cpp:439
unsigned getNumVars() const
Definition Utils.h:124
unsigned getDivOffset() const
Definition Utils.h:128
unsigned getNumDivs() const
Definition Utils.h:125
DynamicAPInt & getDenom(unsigned i)
Definition Utils.h:153
void insertDiv(unsigned pos, ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor)
Definition Utils.cpp:493
void setDiv(unsigned i, ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor)
Definition Utils.h:158
MutableArrayRef< DynamicAPInt > getDividend(unsigned i)
Definition Utils.h:139
An IntegerPolyhedron represents the set of points from a PresburgerSpace that satisfy a list of affin...
An IntegerRelation represents the set of points from a PresburgerSpace that satisfy a list of affine ...
void addBound(BoundType type, unsigned pos, const DynamicAPInt &value)
Adds a constant bound for the specified variable.
void intersectDomain(const IntegerPolyhedron &poly)
Intersect the given poly with the domain in-place.
bool isEqual(const IntegerRelation &other) const
Return whether this and other are equal.
void addEquality(ArrayRef< DynamicAPInt > eq)
Adds an equality from the coefficients specified in eq.
void addInequality(ArrayRef< DynamicAPInt > inEq)
Adds an inequality (>= 0) from the coefficients specified in inEq.
void removeColumn(unsigned pos)
Definition Matrix.cpp:194
void addToColumn(unsigned sourceColumn, unsigned targetColumn, const T &scale)
Add scale multiples of the source column to the target column.
Definition Matrix.cpp:308
void insertColumns(unsigned pos, unsigned count)
Insert columns having positions pos, pos + 1, ... pos + count - 1.
Definition Matrix.cpp:152
void subtract(const MultiAffineFunction &other)
void removeOutputs(unsigned start, unsigned end)
Remove the specified range of outputs.
const PresburgerSpace & getSpace() const
Get the space of this function.
void print(raw_ostream &os) const
MultiAffineFunction(const PresburgerSpace &space, const IntMatrix &output)
PresburgerSpace getDomainSpace() const
Get the domain/output space of the function.
PresburgerSet getLexSet(OrderingKind comp, const MultiAffineFunction &other) const
Return the set of domain points where the output of this and other are ordered lexicographically acco...
SmallVector< DynamicAPInt, 8 > valueAt(ArrayRef< DynamicAPInt > point) const
void mergeDivs(MultiAffineFunction &other)
Given a MAF other, merges division variables such that both functions have the union of the division ...
ArrayRef< DynamicAPInt > getOutputExpr(unsigned i) const
Get the i^th output expression.
IntegerRelation getAsRelation() const
Get this function as a relation.
bool isEqual(const MultiAffineFunction &other) const
Return whether the this and other are equal when the domain is restricted to domain.
This class represents a piece-wise MultiAffineFunction.
void addPiece(const Piece &piece)
unsigned getNumDomainVars() const
void print(raw_ostream &os) const
PWMAFunction unionLexMax(const PWMAFunction &func)
void removeOutputs(unsigned start, unsigned end)
Remove the specified range of outputs.
const PresburgerSpace & getSpace() const
PWMAFunction unionLexMin(const PWMAFunction &func)
Return a function defined on the union of the domains of this and func, such that when only one of th...
PWMAFunction(const PresburgerSpace &space)
std::optional< SmallVector< DynamicAPInt, 8 > > valueAt(ArrayRef< DynamicAPInt > point) const
Return the output of the function at the given point.
PresburgerSet getDomain() const
Return the domain of this piece-wise MultiAffineFunction.
PresburgerSpace getDomainSpace() const
Get the domain/output space of the function.
unsigned getNumSymbolVars() const
bool isEqual(const PWMAFunction &other) const
Return whether this and other are equal as PWMAFunctions, i.e.
bool isIntegerEmpty() const
Return true if all the sets in the union are known to be integer empty false otherwise.
void unionInPlace(const IntegerRelation &disjunct)
Mutate this set, turning it into the union of this set and the given disjunct.
ArrayRef< IntegerRelation > getAllDisjuncts() const
Return a reference to the list of disjuncts.
PresburgerSet intersect(const PresburgerRelation &set) const
static PresburgerSet getEmpty(const PresburgerSpace &space)
Return an empty set of the specified type that contains no points.
PresburgerSpace is the space of all possible values of a tuple of integer valued variables/variables.
void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit)
Removes variables of the specified kind in the column range [varStart, varLimit).
PresburgerSpace getSpaceWithoutLocals() const
Get the space without local variables.
static PresburgerSpace getRelationSpace(unsigned numDomain=0, unsigned numRange=0, unsigned numSymbols=0, unsigned numLocals=0)
unsigned insertVar(VarKind kind, unsigned pos, unsigned num=1)
Insert num variables of the specified kind at position pos.
OrderingKind
Enum representing a binary comparison operator: equal, not equal, less than, less than or equal,...
SmallVector< DynamicAPInt, 8 > getDivUpperBound(ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor, unsigned localVarIdx)
If q is defined to be equal to expr floordiv d, this equivalent to saying that q is an integer and q ...
Definition Utils.cpp:315
SmallVector< DynamicAPInt, 8 > getDivLowerBound(ArrayRef< DynamicAPInt > dividend, const DynamicAPInt &divisor, unsigned localVarIdx)
Definition Utils.cpp:327
Include the generated interface declarations.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.