16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallBitVector.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/raw_ostream.h"
25 #include <type_traits>
35 class AffineExprConstantFolder {
38 : numDims(numDims), operandConsts(operandConsts) {}
43 if (
auto result = constantFoldImpl(expr))
49 std::optional<int64_t> constantFoldImpl(
AffineExpr expr) {
52 return constantFoldBinExpr(
53 expr, [](int64_t lhs, int64_t rhs) {
return lhs + rhs; });
55 return constantFoldBinExpr(
56 expr, [](int64_t lhs, int64_t rhs) {
return lhs * rhs; });
58 return constantFoldBinExpr(
59 expr, [](int64_t lhs, int64_t rhs) {
return mod(lhs, rhs); });
61 return constantFoldBinExpr(
62 expr, [](int64_t lhs, int64_t rhs) {
return floorDiv(lhs, rhs); });
64 return constantFoldBinExpr(
65 expr, [](int64_t lhs, int64_t rhs) {
return ceilDiv(lhs, rhs); });
69 if (
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
74 if (
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
75 operandConsts[numDims +
80 llvm_unreachable(
"Unknown AffineExpr");
84 std::optional<int64_t> constantFoldBinExpr(
AffineExpr expr,
85 int64_t (*op)(int64_t, int64_t)) {
87 if (
auto lhs = constantFoldImpl(binOpExpr.getLHS()))
88 if (
auto rhs = constantFoldImpl(binOpExpr.getRHS()))
89 return op(*lhs, *rhs);
111 assert(dims >= results &&
"Dimension mismatch");
127 broadcastedDims->clear();
132 unsigned resIdx = idxAndExpr.index();
136 if (constExpr.getValue() != 0)
139 broadcastedDims->push_back(resIdx);
142 if (dimExpr.getPosition() != suffixStart + resIdx)
166 unsigned projectionStart =
168 permutedDims.clear();
174 unsigned leadingBroadcast =
179 unsigned resIdx = idxAndExpr.index();
184 if (constExpr.getValue() != 0)
186 broadcastDims.push_back(resIdx);
188 if (dimExpr.getPosition() < projectionStart)
190 unsigned newPosition =
191 dimExpr.getPosition() - projectionStart + leadingBroadcast;
192 permutedDims[resIdx] = newPosition;
193 dimFound[newPosition] =
true;
202 for (
auto dim : broadcastDims) {
203 while (pos < dimFound.size() && dimFound[pos]) {
206 permutedDims[dim] = pos++;
214 assert(!permutation.empty() &&
215 "Cannot create permutation map from empty permutation vector");
217 for (
auto index : permutation)
219 const auto *m = std::max_element(permutation.begin(), permutation.end());
220 auto permutationMap =
AffineMap::get(*m + 1, 0, affExprs, context);
221 assert(permutationMap.isPermutation() &&
"Invalid permutation vector");
222 return permutationMap;
225 template <
typename AffineExprContainer>
228 assert(!exprsList.empty());
229 assert(!exprsList[0].empty());
230 auto context = exprsList[0][0].getContext();
231 int64_t maxDim = -1, maxSym = -1;
234 maps.reserve(exprsList.size());
235 for (
const auto &exprs : exprsList)
237 maxSym + 1, exprs, context));
254 uint64_t thisGcd = resultExpr.getLargestKnownDivisor();
265 dimExprs.reserve(numDims);
266 for (
unsigned i = 0; i < numDims; ++i)
268 return get(numDims, 0, dimExprs, context);
277 for (
unsigned i = 0, numDims =
getNumDims(); i < numDims; ++i) {
279 if (!expr || expr.getPosition() != i)
289 for (
unsigned i = 0, numSymbols =
getNumSymbols(); i < numSymbols; ++i) {
291 if (!expr || expr.getPosition() != i)
317 assert(
isConstant() &&
"map must have only constant results");
325 assert(map &&
"uninitialized map storage");
329 assert(map &&
"uninitialized map storage");
334 assert(map &&
"uninitialized map storage");
338 assert(map &&
"uninitialized map storage");
353 for (
unsigned i = 0, numResults =
getNumResults(); i < numResults; i++) {
373 if (integers.empty())
376 auto range = llvm::map_range(integers, [
this](int64_t i) {
379 results.append(range.begin(), range.end());
389 AffineExprConstantFolder exprFolder(
getNumDims(), operandConstants);
394 auto folded = exprFolder.constantFold(expr);
401 results->push_back(folded.getInt());
403 exprs.push_back(expr);
428 unsigned numResultDims,
429 unsigned numResultSyms)
const {
435 return get(numResultDims, numResultSyms, results,
getContext());
442 unsigned numResultDims,
443 unsigned numResultSyms)
const {
447 newResults.push_back(e.replace(expr, replacement));
455 unsigned numResultDims,
456 unsigned numResultSyms)
const {
460 newResults.push_back(e.replace(map));
469 newResults.push_back(e.replace(map));
474 auto exprs = llvm::to_vector<4>(
getResults());
476 for (
auto pos = positions.find_last(); pos != -1;
477 pos = positions.find_prev(pos))
478 exprs.erase(exprs.begin() + pos);
487 unsigned numSymbols = numSymbolsThisMap + map.
getNumSymbols();
489 for (
unsigned idx = 0; idx < numDims; ++idx) {
493 for (
unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
494 newSymbols[idx - numSymbolsThisMap] =
502 exprs.push_back(expr.
compose(newMap));
509 exprs.reserve(values.size());
511 for (
auto v : values)
515 res.reserve(resMap.getNumResults());
516 for (
auto e : resMap.getResults())
537 if (seen[dim.getPosition()])
539 seen[dim.getPosition()] =
true;
542 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
559 exprs.reserve(resultPos.size());
560 for (
auto idx : resultPos)
596 allExprs.reserve(maps.size() * maps.front().getNumResults());
597 unsigned numDims = maps.front().getNumDims(),
598 numSymbols = maps.front().getNumSymbols();
599 for (
auto m : maps) {
600 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
601 "expected maps with same num dims and symbols");
602 llvm::append_range(allExprs, m.getResults());
605 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
606 unsigned unifiedNumDims = unifiedMap.
getNumDims(),
610 res.reserve(maps.size());
611 for (
auto m : maps) {
613 unifiedResults.take_front(m.getNumResults()),
615 unifiedResults = unifiedResults.drop_front(m.getNumResults());
621 const llvm::SmallBitVector &unusedDims) {
635 const llvm::SmallBitVector &unusedSymbols) {
661 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
670 assert(map.
getNumSymbols() == 0 &&
"expected map without symbols");
673 auto expr = en.value();
676 if (exprs[d.getPosition()])
683 for (
auto expr : exprs)
685 seenExprs.push_back(expr);
697 for (
unsigned i : llvm::seq(
unsigned(0), map.
getNumResults())) {
700 assert(constExpr.getValue() == 0 &&
701 "Unexpected constant in projected permutation");
713 unsigned numResults = 0, numDims = 0, numSymbols = 0;
717 results.reserve(numResults);
718 for (
auto m : maps) {
719 for (
auto res : m.getResults())
720 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
722 numSymbols += m.getNumSymbols();
723 numDims =
std::max(m.getNumDims(), numDims);
726 maps.front().getContext());
733 template <
typename AffineDimOrSymExpr>
735 const llvm::SmallBitVector &toProject,
737 static_assert(llvm::is_one_of<AffineDimOrSymExpr,
AffineDimExpr,
739 "expected AffineDimExpr or AffineSymbolExpr");
741 constexpr
bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
744 replacements.reserve(numDimOrSym);
748 using replace_fn_ty =
754 replace_fn_ty replaceSymbols = [](
AffineExpr e,
758 replace_fn_ty replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
761 int64_t newNumDimOrSym = 0;
762 for (
unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
763 if (toProject.test(dimOrSym)) {
767 int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
768 replacements.push_back(createNewDimOrSym(newPos, context));
773 resultExprs.push_back(replaceNewDimOrSym(e, replacements));
775 int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.
getNumDims();
776 int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.
getNumSymbols();
781 const llvm::SmallBitVector &projectedDimensions,
782 bool compressDimsFlag) {
783 return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
788 const llvm::SmallBitVector &projectedSymbols,
789 bool compressSymbolsFlag) {
790 return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
791 compressSymbolsFlag);
795 const llvm::SmallBitVector &projectedDimensions,
796 bool compressDimsFlag,
797 bool compressSymbolsFlag) {
798 map =
projectDims(map, projectedDimensions, compressDimsFlag);
799 if (compressSymbolsFlag)
805 unsigned numDims = maps[0].getNumDims();
806 llvm::SmallBitVector numDimsBitVector(numDims,
true);
808 for (
unsigned i = 0; i < numDims; ++i) {
809 if (m.isFunctionOfDim(i))
810 numDimsBitVector.reset(i);
813 return numDimsBitVector;
817 unsigned numSymbols = maps[0].getNumSymbols();
818 llvm::SmallBitVector numSymbolsBitVector(numSymbols,
true);
820 for (
unsigned i = 0; i < numSymbols; ++i) {
821 if (m.isFunctionOfSymbol(i))
822 numSymbolsBitVector.reset(i);
825 return numSymbolsBitVector;
830 const llvm::SmallBitVector &projectedDimensions) {
841 : results(map.getResults().begin(), map.getResults().end()),
842 numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
843 context(map.getContext()) {}
850 llvm::append_range(results, map.
getResults());
static SmallVector< AffineMap > compressUnusedListImpl(ArrayRef< AffineMap > maps, llvm::function_ref< AffineMap(AffineMap)> compressionFun)
Implementation detail to compress multiple affine maps with a compressionFun that is expected to be e...
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< AffineExprContainer > exprsList)
static AffineMap projectCommonImpl(AffineMap map, const llvm::SmallBitVector &toProject, bool compress)
Common implementation to project out dimensions or symbols from an affine map based on the template t...
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Affine binary operation expression.
An integer constant appearing in affine expression.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
void walk(std::function< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this expression in postorder.
AffineExprKind getKind() const
Return the classification for this type.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
constexpr bool isa() const
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
MLIRContext * getContext() const
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
int64_t getSingleConstantResult() const
Returns the constant result of this map.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
AffineMap dropResults(ArrayRef< int64_t > positions) const
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
AffineMap getMajorSubMap(unsigned numResults) const
Returns the map consisting of the most major numResults results.
MLIRContext * getContext() const
AffineMap partialConstantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< int64_t > *results=nullptr) const
Propagates the constant operands into this affine map.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
bool isConstant() const
Returns true if this affine map has only constant results.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isSingleConstant() const
Returns true if this affine map is a single result constant function.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
AffineMap getMinorSubMap(unsigned numResults) const
Returns the map consisting of the most minor numResults results.
uint64_t getLargestKnownDivisorOfMapExprs()
Get the largest known divisor of all map expressions.
constexpr AffineMap()=default
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
std::optional< unsigned > getResultPosition(AffineExpr input) const
Extracts the first result position where input dimension resides.
unsigned getNumSymbols() const
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
bool isSymbolIdentity() const
Returns true if this affine map is an identity affine map on the symbol identifiers.
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
void walkExprs(llvm::function_ref< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this mapping.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
A symbolic identifier appearing in an affine expression.
unsigned getPosition() const
MLIRContext is the top-level object for a collection of MLIR operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
This header declares functions that assist transformations in the MemRef dialect.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef< AffineMap > maps)
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineMap compressSymbols(AffineMap map, const llvm::SmallBitVector &unusedSymbols)
Drop the symbols that are listed in unusedSymbols.
static void getMaxDimAndSymbol(ArrayRef< AffineExprContainer > exprsList, int64_t &maxDim, int64_t &maxSym)
Calculates maximum dimension and symbol positions from the expressions in exprsLists and stores them ...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineMap getProjectedMap(AffineMap map, const llvm::SmallBitVector &projectedDimensions, bool compressDimsFlag=true, bool compressSymbolsFlag=true)
Calls projectDims(map, projectedDimensions, compressDimsFlag).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap projectDims(AffineMap map, const llvm::SmallBitVector &projectedDimensions, bool compressDimsFlag=false)
Returns the map that results from projecting out the dimensions specified in projectedDimensions.
AffineMap compressUnusedSymbols(AffineMap map)
Drop the symbols that are not used.
AffineMap projectSymbols(AffineMap map, const llvm::SmallBitVector &projectedSymbols, bool compressSymbolsFlag=false)
Symbol counterpart of projectDims.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
This class represents an efficient way to signal success or failure.
void reset(AffineMap map)
Resets this MutableAffineMap with 'map'.
MutableAffineMap()=default
AffineMap getAffineMap() const
Get the AffineMap corresponding to this MutableAffineMap.
AffineExpr getResult(unsigned idx) const
bool isMultipleOf(unsigned idx, int64_t factor) const
Returns true if the idx'th result expression is a multiple of factor.
unsigned getNumResults() const
void simplify()
Simplify the (result) expressions in this map using analysis (used by.
ArrayRef< AffineExpr > results() const
The affine expressions for this (multi-dimensional) map.