17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallBitVector.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/Support/raw_ostream.h"
26 #include <type_traits>
36 class AffineExprConstantFolder {
39 : numDims(numDims), operandConsts(operandConsts) {}
44 if (
auto result = constantFoldImpl(expr))
49 bool hasPoison()
const {
return hasPoison_; }
52 std::optional<int64_t> constantFoldImpl(
AffineExpr expr) {
55 return constantFoldBinExpr(
56 expr, [](int64_t lhs, int64_t rhs) {
return lhs + rhs; });
58 return constantFoldBinExpr(
59 expr, [](int64_t lhs, int64_t rhs) {
return lhs * rhs; });
61 return constantFoldBinExpr(
62 expr, [
this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
70 return constantFoldBinExpr(
71 expr, [
this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
79 return constantFoldBinExpr(
80 expr, [
this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
88 return cast<AffineConstantExpr>(expr).getValue();
90 if (
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
91 operandConsts[cast<AffineDimExpr>(expr).getPosition()]))
95 if (
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
96 operandConsts[numDims +
97 cast<AffineSymbolExpr>(expr).getPosition()]))
101 llvm_unreachable(
"Unknown AffineExpr");
105 std::optional<int64_t> constantFoldBinExpr(
108 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
109 if (
auto lhs = constantFoldImpl(binOpExpr.getLHS()))
110 if (
auto rhs = constantFoldImpl(binOpExpr.getRHS()))
111 return op(*lhs, *rhs);
119 bool hasPoison_{
false};
134 assert(dims >= results &&
"Dimension mismatch");
145 llvm::SmallBitVector dropDimResults(numDims);
146 for (
auto [idx, resultExpr] :
llvm::enumerate(identityMap.getResults()))
147 dropDimResults[idx] = !keepDimFilter(cast<AffineDimExpr>(resultExpr));
149 return identityMap.dropResults(dropDimResults);
163 broadcastedDims->clear();
168 unsigned resIdx = idxAndExpr.index();
170 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
172 if (constExpr.getValue() != 0)
175 broadcastedDims->push_back(resIdx);
176 }
else if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
178 if (dimExpr.getPosition() != suffixStart + resIdx)
202 unsigned projectionStart =
204 permutedDims.clear();
210 unsigned leadingBroadcast =
215 unsigned resIdx = idxAndExpr.index();
219 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
220 if (constExpr.getValue() != 0)
222 broadcastDims.push_back(resIdx);
223 }
else if (
auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
224 if (dimExpr.getPosition() < projectionStart)
226 unsigned newPosition =
227 dimExpr.getPosition() - projectionStart + leadingBroadcast;
228 permutedDims[resIdx] = newPosition;
229 dimFound[newPosition] =
true;
238 for (
auto dim : broadcastDims) {
239 while (pos < dimFound.size() && dimFound[pos]) {
242 permutedDims[dim] = pos++;
250 assert(!permutation.empty() &&
251 "Cannot create permutation map from empty permutation vector");
252 const auto *m = llvm::max_element(permutation);
254 assert(permutationMap.isPermutation() &&
"Invalid permutation vector");
255 return permutationMap;
260 permutation, [](int64_t i) {
return static_cast<unsigned>(i); });
268 for (
unsigned t : targets)
279 template <
typename AffineExprContainer>
283 if (exprsList.empty())
285 int64_t maxDim = -1, maxSym = -1;
288 maps.reserve(exprsList.size());
289 for (
const auto &exprs : exprsList)
291 maxSym + 1, exprs, context));
310 uint64_t thisGcd = resultExpr.getLargestKnownDivisor();
321 dimExprs.reserve(numDims);
322 for (
unsigned i = 0; i < numDims; ++i)
324 return get(numDims, 0, dimExprs, context);
333 for (
unsigned i = 0, numDims =
getNumDims(); i < numDims; ++i) {
334 auto expr = dyn_cast<AffineDimExpr>(results[i]);
335 if (!expr || expr.getPosition() != i)
345 for (
unsigned i = 0, numSymbols =
getNumSymbols(); i < numSymbols; ++i) {
346 auto expr = dyn_cast<AffineDimExpr>(results[i]);
347 if (!expr || expr.getPosition() != i)
363 return isa<AffineConstantExpr>(expr);
369 return cast<AffineConstantExpr>(
getResult(0)).getValue();
373 assert(
isConstant() &&
"map must have only constant results");
376 result.emplace_back(cast<AffineConstantExpr>(expr).getValue());
381 assert(map &&
"uninitialized map storage");
385 assert(map &&
"uninitialized map storage");
390 assert(map &&
"uninitialized map storage");
394 assert(map &&
"uninitialized map storage");
402 return cast<AffineDimExpr>(
getResult(idx)).getPosition();
406 if (!isa<AffineDimExpr>(input))
409 for (
unsigned i = 0, numResults =
getNumResults(); i < numResults; i++) {
422 bool *hasPoison)
const {
429 if (integers.empty())
432 auto range = llvm::map_range(integers, [
this](int64_t i) {
435 results.append(range.begin(), range.end());
441 bool *hasPoison)
const {
445 AffineExprConstantFolder exprFolder(
getNumDims(), operandConstants);
450 auto folded = exprFolder.constantFold(expr);
451 if (exprFolder.hasPoison() && hasPoison) {
461 results->push_back(folded.getInt());
463 exprs.push_back(expr);
488 unsigned numResultDims,
489 unsigned numResultSyms)
const {
495 return get(numResultDims, numResultSyms, results,
getContext());
502 unsigned numResultDims,
503 unsigned numResultSyms)
const {
507 newResults.push_back(e.replace(expr, replacement));
515 unsigned numResultDims,
516 unsigned numResultSyms)
const {
520 newResults.push_back(e.replace(map));
529 newResults.push_back(e.replace(map));
534 auto exprs = llvm::to_vector<4>(
getResults());
536 for (
auto pos = positions.find_last(); pos != -1;
537 pos = positions.find_prev(pos))
538 exprs.erase(exprs.begin() + pos);
547 unsigned numSymbols = numSymbolsThisMap + map.
getNumSymbols();
549 for (
unsigned idx = 0; idx < numDims; ++idx) {
553 for (
unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
554 newSymbols[idx - numSymbolsThisMap] =
562 exprs.push_back(expr.
compose(newMap));
569 exprs.reserve(values.size());
571 for (
auto v : values)
575 res.reserve(resMap.getNumResults());
576 for (
auto e : resMap.getResults())
577 res.push_back(cast<AffineConstantExpr>(e).getValue());
596 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
597 if (seen[dim.getPosition()])
599 seen[dim.getPosition()] =
true;
601 auto constExpr = dyn_cast<AffineConstantExpr>(expr);
602 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
619 exprs.reserve(resultPos.size());
620 for (
auto idx : resultPos)
656 allExprs.reserve(maps.size() * maps.front().getNumResults());
657 unsigned numDims = maps.front().getNumDims(),
658 numSymbols = maps.front().getNumSymbols();
659 for (
auto m : maps) {
660 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
661 "expected maps with same num dims and symbols");
662 llvm::append_range(allExprs, m.getResults());
665 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
666 unsigned unifiedNumDims = unifiedMap.
getNumDims(),
670 res.reserve(maps.size());
671 for (
auto m : maps) {
673 unifiedResults.take_front(m.getNumResults()),
675 unifiedResults = unifiedResults.drop_front(m.getNumResults());
681 const llvm::SmallBitVector &unusedDims) {
695 const llvm::SmallBitVector &unusedSymbols) {
713 for (int64_t i = 0; i < map.
getNumDims(); ++i) {
714 if (
auto attr = operands[i].dyn_cast<Attribute>()) {
715 dimReplacements.push_back(
719 remainingValues.push_back(operands[i].get<Value>());
722 int64_t numSymbols = 0;
725 symReplacements.push_back(
729 remainingValues.push_back(operands[i + map.
getNumDims()].get<
Value>());
749 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
758 assert(map.
getNumSymbols() == 0 &&
"expected map without symbols");
761 auto expr = en.value();
763 if (
auto d = dyn_cast<AffineDimExpr>(expr)) {
764 if (exprs[d.getPosition()])
771 for (
auto expr : exprs)
773 seenExprs.push_back(expr);
785 for (
unsigned i : llvm::seq(
unsigned(0), map.
getNumResults())) {
787 if (
auto constExpr = dyn_cast<AffineConstantExpr>(map.
getResult(i))) {
788 assert(constExpr.getValue() == 0 &&
789 "Unexpected constant in projected permutation");
801 unsigned numResults = 0, numDims = 0, numSymbols = 0;
805 results.reserve(numResults);
806 for (
auto m : maps) {
807 for (
auto res : m.getResults())
808 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
810 numSymbols += m.getNumSymbols();
811 numDims =
std::max(m.getNumDims(), numDims);
814 maps.front().getContext());
821 template <
typename AffineDimOrSymExpr>
823 const llvm::SmallBitVector &toProject,
825 static_assert(llvm::is_one_of<AffineDimOrSymExpr,
AffineDimExpr,
827 "expected AffineDimExpr or AffineSymbolExpr");
829 constexpr
bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
832 replacements.reserve(numDimOrSym);
836 using replace_fn_ty =
842 replace_fn_ty replaceSymbols = [](
AffineExpr e,
846 replace_fn_ty replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
849 int64_t newNumDimOrSym = 0;
850 for (
unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
851 if (toProject.test(dimOrSym)) {
855 int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
856 replacements.push_back(createNewDimOrSym(newPos, context));
861 resultExprs.push_back(replaceNewDimOrSym(e, replacements));
863 int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.
getNumDims();
864 int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.
getNumSymbols();
869 const llvm::SmallBitVector &projectedDimensions,
870 bool compressDimsFlag) {
871 return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
876 const llvm::SmallBitVector &projectedSymbols,
877 bool compressSymbolsFlag) {
878 return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
879 compressSymbolsFlag);
883 const llvm::SmallBitVector &projectedDimensions,
884 bool compressDimsFlag,
885 bool compressSymbolsFlag) {
886 map =
projectDims(map, projectedDimensions, compressDimsFlag);
887 if (compressSymbolsFlag)
893 unsigned numDims = maps[0].getNumDims();
894 llvm::SmallBitVector numDimsBitVector(numDims,
true);
896 for (
unsigned i = 0; i < numDims; ++i) {
897 if (m.isFunctionOfDim(i))
898 numDimsBitVector.reset(i);
901 return numDimsBitVector;
905 unsigned numSymbols = maps[0].getNumSymbols();
906 llvm::SmallBitVector numSymbolsBitVector(numSymbols,
true);
908 for (
unsigned i = 0; i < numSymbols; ++i) {
909 if (m.isFunctionOfSymbol(i))
910 numSymbolsBitVector.reset(i);
913 return numSymbolsBitVector;
918 const llvm::SmallBitVector &projectedDimensions) {
929 : results(map.getResults().begin(), map.getResults().end()),
930 numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
938 llvm::append_range(results, map.
getResults());
942 return results[idx].isMultipleOf(factor);
static SmallVector< AffineMap > compressUnusedListImpl(ArrayRef< AffineMap > maps, llvm::function_ref< AffineMap(AffineMap)> compressionFun)
Implementation detail to compress multiple affine maps with a compressionFun that is expected to be e...
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< AffineExprContainer > exprsList, MLIRContext *context)
Creates an affine map each for each list of AffineExpr's in exprsList while inferring the right numbe...
static AffineMap projectCommonImpl(AffineMap map, const llvm::SmallBitVector &toProject, bool compress)
Common implementation to project out dimensions or symbols from an affine map based on the template t...
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
AffineExpr replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements) const
This method substitutes any uses of dimensions and symbols (e.g.
RetT walk(FnT &&callback) const
Walk all of the AffineExpr's in this expression in postorder.
AffineExprKind getKind() const
Return the classification for this type.
AffineExpr compose(AffineMap map) const
Compose with an AffineMap.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
MLIRContext * getContext() const
AffineExpr replaceSymbols(ArrayRef< AffineExpr > symReplacements) const
Symbol-only version of replaceDimsAndSymbols.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
int64_t getSingleConstantResult() const
Returns the constant result of this map.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
AffineMap dropResults(ArrayRef< int64_t > positions) const
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
AffineMap getMajorSubMap(unsigned numResults) const
Returns the map consisting of the most major numResults results.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
bool isConstant() const
Returns true if this affine map has only constant results.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isSingleConstant() const
Returns true if this affine map is a single result constant function.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
AffineMap getMinorSubMap(unsigned numResults) const
Returns the map consisting of the most minor numResults results.
uint64_t getLargestKnownDivisorOfMapExprs()
Get the largest known divisor of all map expressions.
constexpr AffineMap()=default
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
std::optional< unsigned > getResultPosition(AffineExpr input) const
Extracts the first result position where input dimension resides.
unsigned getNumSymbols() const
bool isMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > *broadcastedDims=nullptr) const
Returns true if this affine map is a minor identity up to broadcasted dimensions which are indicated ...
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
bool isSymbolIdentity() const
Returns true if this affine map is an identity affine map on the symbol identifiers.
unsigned getNumResults() const
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
void walkExprs(llvm::function_ref< void(AffineExpr)> callback) const
Walk all of the AffineExpr's in this mapping.
AffineMap partialConstantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< int64_t > *results=nullptr, bool *hasPoison=nullptr) const
Propagates the constant operands into this affine map.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
A symbolic identifier appearing in an affine expression.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef< AffineMap > maps)
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
AffineMap compressSymbols(AffineMap map, const llvm::SmallBitVector &unusedSymbols)
Drop the symbols that are listed in unusedSymbols.
static void getMaxDimAndSymbol(ArrayRef< AffineExprContainer > exprsList, int64_t &maxDim, int64_t &maxSym)
Calculates maximum dimension and symbol positions from the expressions in exprsLists and stores them ...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineMap getProjectedMap(AffineMap map, const llvm::SmallBitVector &projectedDimensions, bool compressDimsFlag=true, bool compressSymbolsFlag=true)
Calls projectDims(map, projectedDimensions, compressDimsFlag).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap projectDims(AffineMap map, const llvm::SmallBitVector &projectedDimensions, bool compressDimsFlag=false)
Returns the map that results from projecting out the dimensions specified in projectedDimensions.
AffineMap compressUnusedSymbols(AffineMap map)
Drop the symbols that are not used.
AffineMap projectSymbols(AffineMap map, const llvm::SmallBitVector &projectedSymbols, bool compressSymbolsFlag=false)
Symbol counterpart of projectDims.
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
This class represents an efficient way to signal success or failure.
void reset(AffineMap map)
Resets this MutableAffineMap with 'map'.
MutableAffineMap()=default
AffineMap getAffineMap() const
Get the AffineMap corresponding to this MutableAffineMap.
AffineExpr getResult(unsigned idx) const
bool isMultipleOf(unsigned idx, int64_t factor) const
Returns true if the idx'th result expression is a multiple of factor.
unsigned getNumResults() const
void simplify()
Simplify the (result) expressions in this map using analysis (used by.
ArrayRef< AffineExpr > results() const
The affine expressions for this (multi-dimensional) map.