15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/SmallBitVector.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/Support/MathExtras.h"
21 #include "llvm/Support/raw_ostream.h"
25 #include <type_traits>
29 using llvm::divideCeilSigned;
30 using llvm::divideFloorSigned;
39 class AffineExprConstantFolder {
42 : numDims(numDims), operandConsts(operandConsts) {}
47 if (
auto result = constantFoldImpl(expr))
52 bool hasPoison()
const {
return hasPoison_; }
55 std::optional<int64_t> constantFoldImpl(
AffineExpr expr) {
58 return constantFoldBinExpr(
59 expr, [](int64_t lhs, int64_t rhs) {
return lhs + rhs; });
61 return constantFoldBinExpr(
62 expr, [](int64_t lhs, int64_t rhs) {
return lhs * rhs; });
64 return constantFoldBinExpr(
65 expr, [
this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
73 return constantFoldBinExpr(
74 expr, [
this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
79 return divideFloorSigned(lhs, rhs);
82 return constantFoldBinExpr(
83 expr, [
this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
88 return divideCeilSigned(lhs, rhs);
91 return cast<AffineConstantExpr>(expr).getValue();
93 if (
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
94 operandConsts[cast<AffineDimExpr>(expr).getPosition()]))
98 if (
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
99 operandConsts[numDims +
100 cast<AffineSymbolExpr>(expr).getPosition()]))
101 return attr.getInt();
104 llvm_unreachable(
"Unknown AffineExpr");
108 std::optional<int64_t> constantFoldBinExpr(
111 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
112 if (
auto lhs = constantFoldImpl(binOpExpr.getLHS()))
113 if (
auto rhs = constantFoldImpl(binOpExpr.getRHS()))
114 return op(*lhs, *rhs);
122 bool hasPoison_{
false};
137 assert(dims >= results &&
"Dimension mismatch");
148 llvm::SmallBitVector dropDimResults(numDims);
149 for (
auto [idx, resultExpr] :
llvm::enumerate(identityMap.getResults()))
150 dropDimResults[idx] = !keepDimFilter(cast<AffineDimExpr>(resultExpr));
152 return identityMap.dropResults(dropDimResults);
164 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
165 if (constExpr.getValue() != 0)
167 broadcastedDims.push_back(resIdx);
171 return broadcastedDims;
179 broadcastedDims->clear();
184 unsigned resIdx = idxAndExpr.index();
186 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
188 if (constExpr.getValue() != 0)
191 broadcastedDims->push_back(resIdx);
192 }
else if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
194 if (dimExpr.getPosition() != suffixStart + resIdx)
218 unsigned projectionStart =
220 permutedDims.clear();
226 unsigned leadingBroadcast =
231 unsigned resIdx = idxAndExpr.index();
235 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
236 if (constExpr.getValue() != 0)
238 broadcastDims.push_back(resIdx);
239 }
else if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
240 if (dimExpr.getPosition() < projectionStart)
242 unsigned newPosition =
243 dimExpr.getPosition() - projectionStart + leadingBroadcast;
244 permutedDims[resIdx] = newPosition;
245 dimFound[newPosition] =
true;
254 for (
auto dim : broadcastDims) {
255 while (pos < dimFound.size() && dimFound[pos]) {
258 permutedDims[dim] = pos++;
266 assert(!permutation.empty() &&
267 "Cannot create permutation map from empty permutation vector");
268 const auto *m = llvm::max_element(permutation);
270 assert(permutationMap.isPermutation() &&
"Invalid permutation vector");
271 return permutationMap;
276 permutation, [](int64_t i) {
return static_cast<unsigned>(i); });
284 for (
unsigned t : targets)
295 template <
typename AffineExprContainer>
299 if (exprsList.empty())
301 int64_t maxDim = -1, maxSym = -1;
304 maps.reserve(exprsList.size());
305 for (
const auto &exprs : exprsList)
307 maxSym + 1, exprs, context));
326 uint64_t thisGcd = resultExpr.getLargestKnownDivisor();
327 gcd = std::gcd(gcd, thisGcd);
337 dimExprs.reserve(numDims);
338 for (
unsigned i = 0; i < numDims; ++i)
340 return get(numDims, 0, dimExprs, context);
349 for (
unsigned i = 0, numDims =
getNumDims(); i < numDims; ++i) {
350 auto expr = dyn_cast<AffineDimExpr>(results[i]);
351 if (!expr || expr.getPosition() != i)
361 for (
unsigned i = 0, numSymbols =
getNumSymbols(); i < numSymbols; ++i) {
362 auto expr = dyn_cast<AffineDimExpr>(results[i]);
363 if (!expr || expr.getPosition() != i)
378 return llvm::all_of(
getResults(), llvm::IsaPred<AffineConstantExpr>);
383 return cast<AffineConstantExpr>(
getResult(0)).getValue();
387 assert(
isConstant() &&
"map must have only constant results");
390 result.emplace_back(cast<AffineConstantExpr>(expr).getValue());
395 assert(map &&
"uninitialized map storage");
399 assert(map &&
"uninitialized map storage");
404 assert(map &&
"uninitialized map storage");
408 assert(map &&
"uninitialized map storage");
416 return cast<AffineDimExpr>(
getResult(idx)).getPosition();
420 if (!isa<AffineDimExpr>(input))
423 for (
unsigned i = 0, numResults =
getNumResults(); i < numResults; i++) {
436 bool *hasPoison)
const {
443 if (integers.empty())
446 auto range = llvm::map_range(integers, [
this](int64_t i) {
449 results.append(range.begin(), range.end());
455 bool *hasPoison)
const {
459 AffineExprConstantFolder exprFolder(
getNumDims(), operandConstants);
464 auto folded = exprFolder.constantFold(expr);
465 if (exprFolder.hasPoison() && hasPoison) {
475 results->push_back(folded.getInt());
477 exprs.push_back(expr);
502 unsigned numResultDims,
503 unsigned numResultSyms)
const {
509 return get(numResultDims, numResultSyms, results,
getContext());
516 unsigned numResultDims,
517 unsigned numResultSyms)
const {
521 newResults.push_back(e.replace(expr, replacement));
529 unsigned numResultDims,
530 unsigned numResultSyms)
const {
534 newResults.push_back(e.replace(map));
543 newResults.push_back(e.replace(map));
548 auto exprs = llvm::to_vector<4>(
getResults());
550 for (
auto pos = positions.find_last(); pos != -1;
551 pos = positions.find_prev(pos))
552 exprs.erase(exprs.begin() + pos);
561 unsigned numSymbols = numSymbolsThisMap + map.
getNumSymbols();
563 for (
unsigned idx = 0; idx < numDims; ++idx) {
567 for (
unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
568 newSymbols[idx - numSymbolsThisMap] =
576 exprs.push_back(expr.
compose(newMap));
583 exprs.reserve(values.size());
585 for (
auto v : values)
589 res.reserve(resMap.getNumResults());
590 for (
auto e : resMap.getResults())
591 res.push_back(cast<AffineConstantExpr>(e).getValue());
610 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
611 if (seen[dim.getPosition()])
613 seen[dim.getPosition()] =
true;
615 auto constExpr = dyn_cast<AffineConstantExpr>(expr);
616 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
633 exprs.reserve(resultPos.size());
634 for (
auto idx : resultPos)
670 allExprs.reserve(maps.size() * maps.front().getNumResults());
671 unsigned numDims = maps.front().getNumDims(),
672 numSymbols = maps.front().getNumSymbols();
673 for (
auto m : maps) {
674 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
675 "expected maps with same num dims and symbols");
676 llvm::append_range(allExprs, m.getResults());
679 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
680 unsigned unifiedNumDims = unifiedMap.
getNumDims(),
684 res.reserve(maps.size());
685 for (
auto m : maps) {
687 unifiedResults.take_front(m.getNumResults()),
689 unifiedResults = unifiedResults.drop_front(m.getNumResults());
695 const llvm::SmallBitVector &unusedDims) {
709 const llvm::SmallBitVector &unusedSymbols) {
727 for (int64_t i = 0; i < map.
getNumDims(); ++i) {
728 if (
auto attr = operands[i].dyn_cast<Attribute>()) {
729 dimReplacements.push_back(
733 remainingValues.push_back(operands[i].get<Value>());
736 int64_t numSymbols = 0;
739 symReplacements.push_back(
743 remainingValues.push_back(operands[i + map.
getNumDims()].get<
Value>());
763 uniqueExprs.erase(llvm::unique(uniqueExprs), uniqueExprs.end());
771 assert(map.
getNumSymbols() == 0 &&
"expected map without symbols");
774 auto expr = en.value();
776 if (
auto d = dyn_cast<AffineDimExpr>(expr)) {
777 if (exprs[d.getPosition()])
784 for (
auto expr : exprs)
786 seenExprs.push_back(expr);
798 for (
unsigned i : llvm::seq(
unsigned(0), map.
getNumResults())) {
800 if (
auto constExpr = dyn_cast<AffineConstantExpr>(map.
getResult(i))) {
801 assert(constExpr.getValue() == 0 &&
802 "Unexpected constant in projected permutation");
814 unsigned numResults = 0, numDims = 0, numSymbols = 0;
816 numResults += m.getNumResults();
818 results.reserve(numResults);
819 for (
auto m : maps) {
820 for (
auto res : m.getResults())
821 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
823 numSymbols += m.getNumSymbols();
824 numDims =
std::max(m.getNumDims(), numDims);
827 maps.front().getContext());
834 template <
typename AffineDimOrSymExpr>
836 const llvm::SmallBitVector &toProject,
838 static_assert(llvm::is_one_of<AffineDimOrSymExpr,
AffineDimExpr,
840 "expected AffineDimExpr or AffineSymbolExpr");
842 constexpr
bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
845 replacements.reserve(numDimOrSym);
849 using replace_fn_ty =
855 replace_fn_ty replaceSymbols = [](
AffineExpr e,
859 replace_fn_ty replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
862 int64_t newNumDimOrSym = 0;
863 for (
unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
864 if (toProject.test(dimOrSym)) {
868 int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
869 replacements.push_back(createNewDimOrSym(newPos, context));
874 resultExprs.push_back(replaceNewDimOrSym(e, replacements));
876 int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.
getNumDims();
877 int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.
getNumSymbols();
882 const llvm::SmallBitVector &projectedDimensions,
883 bool compressDimsFlag) {
884 return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
889 const llvm::SmallBitVector &projectedSymbols,
890 bool compressSymbolsFlag) {
891 return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
892 compressSymbolsFlag);
896 const llvm::SmallBitVector &projectedDimensions,
897 bool compressDimsFlag,
898 bool compressSymbolsFlag) {
899 map =
projectDims(map, projectedDimensions, compressDimsFlag);
900 if (compressSymbolsFlag)
906 unsigned numDims = maps[0].getNumDims();
907 llvm::SmallBitVector numDimsBitVector(numDims,
true);
909 for (
unsigned i = 0; i < numDims; ++i) {
910 if (m.isFunctionOfDim(i))
911 numDimsBitVector.reset(i);
914 return numDimsBitVector;
918 unsigned numSymbols = maps[0].getNumSymbols();
919 llvm::SmallBitVector numSymbolsBitVector(numSymbols,
true);
921 for (
unsigned i = 0; i < numSymbols; ++i) {
922 if (m.isFunctionOfSymbol(i))
923 numSymbolsBitVector.reset(i);
926 return numSymbolsBitVector;
931 const llvm::SmallBitVector &projectedDimensions) {
942 : results(map.getResults()), numDims(map.getNumDims()),
943 numSymbols(map.getNumSymbols()), context(map.
getContext()) {}
950 llvm::append_range(results, map.
getResults());
954 return results[idx].isMultipleOf(factor);
static SmallVector< AffineMap > compressUnusedListImpl(ArrayRef< AffineMap > maps, llvm::function_ref< AffineMap(AffineMap)> compressionFun)
Implementation detail to compress multiple affine maps with a compressionFun that is expected to be e...
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< AffineExprContainer > exprsList, MLIRContext *context)
Creates an affine map each for each list of AffineExpr's in exprsList while inferring the right numbe...
static AffineMap projectCommonImpl(AffineMap map, const llvm::SmallBitVector &toProject, bool compress)
Common implementation to project out dimensions or symbols from an affine map based on the template t...
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
RetT walk(FnT &&callback) const
Walk all of the AffineExpr's in this expression in postorder.
AffineExprKind getKind() const
Return the classification for this type.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
MLIRContext * getContext() const
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
int64_t getSingleConstantResult() const
Returns the constant result of this map.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
AffineMap dropResults(ArrayRef< int64_t > positions) const
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
AffineMap getMajorSubMap(unsigned numResults) const
Returns the map consisting of the most major numResults results.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
bool isConstant() const
Returns true if this affine map has only constant results.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isSingleConstant() const
Returns true if this affine map is a single result constant function.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
AffineMap getMinorSubMap(unsigned numResults) const
Returns the map consisting of the most minor numResults results.
uint64_t getLargestKnownDivisorOfMapExprs()
Get the largest known divisor of all map expressions.
constexpr AffineMap()=default
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
std::optional< unsigned > getResultPosition(AffineExpr input) const
Extracts the first result position where input dimension resides.
unsigned getNumSymbols() const
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
bool isSymbolIdentity() const
Returns true if this affine map is an identity affine map on the symbol identifiers.
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
void walkExprs(llvm::function_ref< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this mapping.
AffineMap partialConstantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< int64_t > *results=nullptr, bool *hasPoison=nullptr) const
Propagates the constant operands into this affine map.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
A symbolic identifier appearing in an affine expression.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef< AffineMap > maps)
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineMap compressSymbols(AffineMap map, const llvm::SmallBitVector &unusedSymbols)
Drop the symbols that are listed in unusedSymbols.
static void getMaxDimAndSymbol(ArrayRef< AffineExprContainer > exprsList, int64_t &maxDim, int64_t &maxSym)
Calculates maximum dimension and symbol positions from the expressions in exprsLists and stores them ...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineMap getProjectedMap(AffineMap map, const llvm::SmallBitVector &projectedDimensions, bool compressDimsFlag=true, bool compressSymbolsFlag=true)
Calls projectDims(map, projectedDimensions, compressDimsFlag).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap projectDims(AffineMap map, const llvm::SmallBitVector &projectedDimensions, bool compressDimsFlag=false)
Returns the map that results from projecting out the dimensions specified in projectedDimensions.
AffineMap compressUnusedSymbols(AffineMap map)
Drop the symbols that are not used.
AffineMap projectSymbols(AffineMap map, const llvm::SmallBitVector &projectedSymbols, bool compressSymbolsFlag=false)
Symbol counterpart of projectDims.
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
void reset(AffineMap map)
Resets this MutableAffineMap with 'map'.
MutableAffineMap()=default
AffineMap getAffineMap() const
Get the AffineMap corresponding to this MutableAffineMap.
AffineExpr getResult(unsigned idx) const
bool isMultipleOf(unsigned idx, int64_t factor) const
Returns true if the idx'th result expression is a multiple of factor.
unsigned getNumResults() const
void simplify()
Simplify the (result) expressions in this map using analysis (used by.
ArrayRef< AffineExpr > results() const
The affine expressions for this (multi-dimensional) map.