15 #include "llvm/ADT/SmallBitVector.h" 16 #include "llvm/ADT/SmallSet.h" 17 #include "llvm/ADT/StringRef.h" 18 #include "llvm/Support/raw_ostream.h" 28 class AffineExprConstantFolder {
31 : numDims(numDims), operandConsts(operandConsts) {}
36 if (
auto result = constantFoldImpl(expr))
37 return IntegerAttr::get(IndexType::get(expr.
getContext()), *result);
45 return constantFoldBinExpr(
46 expr, [](int64_t lhs, int64_t rhs) {
return lhs + rhs; });
48 return constantFoldBinExpr(
49 expr, [](int64_t lhs, int64_t rhs) {
return lhs * rhs; });
51 return constantFoldBinExpr(
52 expr, [](int64_t lhs, int64_t rhs) {
return mod(lhs, rhs); });
54 return constantFoldBinExpr(
55 expr, [](int64_t lhs, int64_t rhs) {
return floorDiv(lhs, rhs); });
57 return constantFoldBinExpr(
58 expr, [](int64_t lhs, int64_t rhs) {
return ceilDiv(lhs, rhs); });
63 .dyn_cast_or_null<IntegerAttr>())
67 if (
auto attr = operandConsts[numDims +
69 .dyn_cast_or_null<IntegerAttr>())
73 llvm_unreachable(
"Unknown AffineExpr");
78 int64_t (*op)(int64_t, int64_t)) {
80 if (
auto lhs = constantFoldImpl(binOpExpr.getLHS()))
81 if (
auto rhs = constantFoldImpl(binOpExpr.getRHS()))
82 return op(*lhs, *rhs);
104 assert(dims >= results &&
"Dimension mismatch");
106 return AffineMap::get(dims, 0,
id.getResults().take_back(results), context);
110 return getNumDims() >= getNumResults() &&
112 getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
120 broadcastedDims->clear();
121 if (getNumDims() < getNumResults())
123 unsigned suffixStart = getNumDims() - getNumResults();
125 unsigned resIdx = idxAndExpr.index();
129 if (constExpr.getValue() != 0)
132 broadcastedDims->push_back(resIdx);
135 if (dimExpr.getPosition() != suffixStart + resIdx)
159 unsigned projectionStart =
160 getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0;
161 permutedDims.clear();
163 permutedDims.resize(getNumResults(), 0);
167 unsigned leadingBroadcast =
168 getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0;
169 llvm::SmallBitVector dimFound(
std::max(getNumInputs(), getNumResults()),
172 unsigned resIdx = idxAndExpr.index();
177 if (constExpr.getValue() != 0)
179 broadcastDims.push_back(resIdx);
181 if (dimExpr.getPosition() < projectionStart)
183 unsigned newPosition =
184 dimExpr.getPosition() - projectionStart + leadingBroadcast;
185 permutedDims[resIdx] = newPosition;
186 dimFound[newPosition] =
true;
195 for (
auto dim : broadcastDims) {
196 while (pos < dimFound.size() && dimFound[pos]) {
199 permutedDims[dim] = pos++;
207 assert(!permutation.empty() &&
208 "Cannot create permutation map from empty permutation vector");
210 for (
auto index : permutation)
212 const auto *m = std::max_element(permutation.begin(), permutation.end());
213 auto permutationMap =
AffineMap::get(*m + 1, 0, affExprs, context);
214 assert(permutationMap.isPermutation() &&
"Invalid permutation vector");
215 return permutationMap;
218 template <
typename AffineExprContainer>
221 assert(!exprsList.empty());
222 assert(!exprsList[0].empty());
223 auto context = exprsList[0][0].getContext();
224 int64_t maxDim = -1, maxSym = -1;
227 maps.reserve(exprsList.size());
228 for (
const auto &exprs : exprsList)
230 maxSym + 1, exprs, context));
247 dimExprs.reserve(numDims);
248 for (
unsigned i = 0; i < numDims; ++i)
250 return get(numDims, 0, dimExprs, context);
256 if (getNumDims() != getNumResults())
259 for (
unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
261 if (!expr || expr.getPosition() != i)
268 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
276 return llvm::all_of(getResults(), [](
AffineExpr expr) {
282 assert(isSingleConstant() &&
"map must have a single constant result");
287 assert(isConstant() &&
"map must have only constant results");
289 for (
auto expr : getResults())
295 assert(map &&
"uninitialized map storage");
299 assert(map &&
"uninitialized map storage");
300 return map->numSymbols;
304 assert(map &&
"uninitialized map storage");
305 return map->numDims + map->numSymbols;
308 assert(map &&
"uninitialized map storage");
309 return map->results();
312 return getResults()[idx];
321 for (
unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
324 llvm_unreachable(
"incorrect permutation request");
335 partialConstantFold(operandConstants, &integers);
339 if (integers.empty())
342 auto range = llvm::map_range(integers, [
this](int64_t i) {
343 return IntegerAttr::get(IndexType::get(getContext()), i);
345 results.append(range.begin(), range.end());
352 assert(getNumInputs() == operandConstants.size());
355 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
357 exprs.reserve(getNumResults());
359 for (
auto expr : getResults()) {
360 auto folded = exprFolder.constantFold(expr);
367 results->push_back(folded.getInt());
369 exprs.push_back(expr);
377 return get(getNumDims(), getNumSymbols(), exprs, getContext());
383 for (
auto expr : getResults())
394 unsigned numResultDims,
395 unsigned numResultSyms)
const {
397 results.reserve(getNumResults());
398 for (
auto expr : getResults())
401 return get(numResultDims, numResultSyms, results, getContext());
408 unsigned numResultDims,
409 unsigned numResultSyms)
const {
411 newResults.reserve(getNumResults());
413 newResults.push_back(e.replace(expr, replacement));
414 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext());
421 unsigned numResultDims,
422 unsigned numResultSyms)
const {
424 newResults.reserve(getNumResults());
426 newResults.push_back(e.replace(map));
427 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext());
433 newResults.reserve(getNumResults());
435 newResults.push_back(e.replace(map));
440 assert(getNumDims() == map.
getNumResults() &&
"Number of results mismatch");
443 unsigned numSymbolsThisMap = getNumSymbols();
444 unsigned numSymbols = numSymbolsThisMap + map.
getNumSymbols();
446 for (
unsigned idx = 0; idx < numDims; ++idx) {
450 for (
unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
451 newSymbols[idx - numSymbolsThisMap] =
457 exprs.reserve(getResults().size());
458 for (
auto expr : getResults())
459 exprs.push_back(expr.
compose(newMap));
464 assert(getNumSymbols() == 0 &&
"Expected symbol-less map");
466 exprs.reserve(values.size());
468 for (
auto v : values)
472 res.reserve(resMap.getNumResults());
473 for (
auto e : resMap.getResults())
479 if (getNumSymbols() > 0)
484 if (getNumResults() > getNumInputs())
492 for (
auto expr : getResults()) {
494 if (seen[dim.getPosition()])
496 seen[dim.getPosition()] =
true;
499 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
509 if (getNumDims() != getNumResults())
511 return isProjectedPermutation();
516 exprs.reserve(resultPos.size());
517 for (
auto idx : resultPos)
518 exprs.push_back(getResult(idx));
519 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
524 getResults().slice(start, length), getContext());
530 if (numResults > getNumResults())
532 return getSliceMap(0, numResults);
538 if (numResults > getNumResults())
540 return getSliceMap(getNumResults() - numResults, numResults);
544 const llvm::SmallBitVector &unusedDims) {
545 unsigned numDims = 0;
549 for (
unsigned dim = 0, e = map.
getNumDims(); dim < e; ++dim) {
550 if (unusedDims.test(dim))
558 resultExprs.push_back(e.replaceDims(dimReplacements));
563 llvm::SmallBitVector unusedDims(map.
getNumDims(),
true);
566 unusedDims.reset(dimExpr.getPosition());
577 allExprs.reserve(maps.size() * maps.front().getNumResults());
578 unsigned numDims = maps.front().getNumDims(),
579 numSymbols = maps.front().getNumSymbols();
580 for (
auto m : maps) {
581 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
582 "expected maps with same num dims and symbols");
583 llvm::append_range(allExprs, m.getResults());
586 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
587 unsigned unifiedNumDims = unifiedMap.getNumDims(),
588 unifiedNumSymbols = unifiedMap.getNumSymbols();
591 res.reserve(maps.size());
592 for (
auto m : maps) {
594 unifiedResults.take_front(m.getNumResults()),
596 unifiedResults = unifiedResults.drop_front(m.getNumResults());
607 const llvm::SmallBitVector &unusedSymbols) {
608 unsigned numSymbols = 0;
612 for (
unsigned sym = 0, e = map.
getNumSymbols(); sym < e; ++sym) {
613 if (unusedSymbols.test(sym))
621 resultExprs.push_back(e.replaceSymbols(symReplacements));
626 llvm::SmallBitVector unusedSymbols(map.
getNumSymbols(),
true);
629 unusedSymbols.reset(symExpr.getPosition());
652 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
661 assert(map.
getNumSymbols() == 0 &&
"expected map without symbols");
664 auto expr = en.value();
667 if (exprs[d.getPosition()])
674 for (
auto expr : exprs)
676 seenExprs.push_back(expr);
688 for (
unsigned i : llvm::seq(
unsigned(0), map.
getNumResults())) {
691 assert(constExpr.getValue() == 0 &&
692 "Unexpected constant in projected permutation");
704 unsigned numResults = 0, numDims = 0, numSymbols = 0;
706 numResults += m.getNumResults();
708 results.reserve(numResults);
709 for (
auto m : maps) {
710 for (
auto res : m.getResults())
711 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
713 numSymbols += m.getNumSymbols();
714 numDims =
std::max(m.getNumDims(), numDims);
717 maps.front().getContext());
721 const llvm::SmallBitVector &unusedDims) {
730 : results(map.getResults().begin(), map.getResults().end()),
731 numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
732 context(map.getContext()) {}
739 llvm::append_range(results, map.
getResults());
Affine binary operation expression.
Include the generated interface declarations.
unsigned getPermutedPosition(unsigned input) const
Extracts the permuted position where given input index resides.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
bool isConstant() const
Returns true if this affine map has only constant results.
RHS of mod is always a constant or a symbolic expression with a positive value.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
static void getMaxDimAndSymbol(ArrayRef< AffineExprContainer > exprsList, int64_t &maxDim, int64_t &maxSym)
Calculates maxmimum dimension and symbol positions from the expressions in exprsLists and stores them...
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
unsigned getNumSymbols() const
unsigned getNumDims() const
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
unsigned getNumResults() const
void simplify()
Simplify the (result) expressions in this map using analysis (used by.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
static bool isPermutation(std::vector< PermutationTy > permutation)
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getPosition() const
An integer constant appearing in affine expression.
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this expression in postorder.
AffineMap compressSymbols(AffineMap map, const llvm::SmallBitVector &unusedSymbols)
Drop the symbols that are not listed in unusedSymbols.
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getResult(unsigned idx) const
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
unsigned getNumInputs() const
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
int64_t getSingleConstantResult() const
Returns the constant result of this map.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps...
Base type for affine expression.
MLIRContext * getContext() const
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
RHS of mul is always a constant or a symbolic expression.
unsigned getNumResults() const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued...
MutableAffineMap()=default
RHS of floordiv is always a constant or a symbolic expression.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
ArrayRef< AffineExpr > getResults() const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap partialConstantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< int64_t > *results=nullptr) const
Propagates the constant operands into this affine map.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
RHS of ceildiv is always a constant or a symbolic expression.
unsigned getPosition() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
AffineMap getMinorSubMap(unsigned numResults) const
Returns the map consisting of the most minor numResults results.
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
AffineExpr getResult(unsigned idx) const
AffineExprKind getKind() const
Return the classification for this type.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are not listed in unusedDims.
AffineMap getProjectedMap(AffineMap map, const llvm::SmallBitVector &projectedDimensions)
Returns the map that results from projecting out the dimensions specified in projectedDimensions.
A dimensional identifier appearing in an affine expression.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isMultipleOf(unsigned idx, int64_t factor) const
Returns true if the idx'th result expression is a multiple of factor.
MLIRContext is the top-level object for a collection of MLIR operations.
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
void walkExprs(llvm::function_ref< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this mapping.
AffineMap getAffineMap() const
Get the AffineMap corresponding to this MutableAffineMap.
AffineMap getMajorSubMap(unsigned numResults) const
Returns the map consisting of the most major numResults results.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
MLIRContext * getContext() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineMap compressUnusedSymbols(AffineMap map)
Drop the symbols that are not used.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
bool isSingleConstant() const
Returns true if this affine map is a single result constant function.
static SmallVector< AffineMap > compressUnusedImpl(ArrayRef< AffineMap > maps, llvm::function_ref< AffineMap(AffineMap)> compressionFun)
void reset(AffineMap map)
Resets this MutableAffineMap with 'map'.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A symbolic identifier appearing in an affine expression.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< AffineExprContainer > exprsList)