18#include "llvm/ADT/STLExtras.h"
24 if (!type.hasStaticShape())
29 if (failed(type.getStridesAndOffset(strides, offset)))
35 int64_t curDim = strides.size() - 1;
37 while (curDim >= 0 && strides[curDim] == runningStride) {
38 runningStride *= type.getDimSize(curDim);
43 while (curDim >= 0 && type.getDimSize(curDim) == 1) {
56 unsigned sourceRank = sizes.size();
57 assert(sizes.size() == strides.size() &&
58 "expected as many sizes as strides for a memref");
62 assert(indicesVec.size() == strides.size() &&
63 "expected as many indices as rank of memref");
72 for (
unsigned i = 0; i < sourceRank; ++i) {
73 unsigned offsetIdx = 2 * i;
74 addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
75 offsetValues[offsetIdx] = indicesVec[i];
76 offsetValues[offsetIdx + 1] = strides[i];
79 int64_t scaler = dstBits / srcBits;
81 builder, loc, addMulMap.
floorDiv(scaler), offsetValues);
83 size_t symbolIndex = 0;
86 for (
unsigned i = 0; i < sourceRank; ++i) {
87 AffineExpr strideExpr = symbols[symbolIndex++];
88 values.push_back(strides[i]);
90 values.push_back(sizes[i]);
98 0, symbolIndex, productExpressions,
107 builder, loc, s0.
floorDiv(scaler), {offset});
110 builder, loc, addMulMap % scaler, offsetValues);
112 return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
122 if (!sizes.empty()) {
128 builder, loc, s0 * s1,
134 std::tie(linearizedMemRefInfo, std::ignore) =
138 return linearizedMemRefInfo;
146 std::vector<Operation *> opUses;
152 if (isa<memref::DeallocOp>(useOp) ||
156 opUses.push_back(useOp);
161 llvm::append_range(uses, opUses);
166 std::vector<Operation *> opToErase;
168 std::vector<Operation *> candidates;
169 if (isa<memref::AllocOp, memref::AllocaOp>(op) &&
171 llvm::append_range(opToErase, candidates);
172 opToErase.push_back(op);
188 for (
int64_t r = strides.size() - 1; r > 0; --r) {
190 builder, loc, s0 * s1, {strides[r], sizes[r]});
203 while (
auto *op = source.getDefiningOp()) {
204 if (
auto subViewOp = dyn_cast<memref::SubViewOp>(op);
205 subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
208 source = cast<MemrefValue>(subViewOp.getSource());
209 }
else if (
auto castOp = dyn_cast<memref::CastOp>(op)) {
211 source = castOp.getSource();
220 while (
auto *op = source.getDefiningOp()) {
221 if (
auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
222 if (source == viewLike.getViewDest()) {
223 source = cast<MemrefValue>(viewLike.getViewSource());
233 memref::ExpandShapeOp expandShapeOp,
236 bool startsInbounds) {
242 assert(!group.empty() &&
"association indices groups cannot be empty");
243 int64_t groupSize = group.size();
244 if (groupSize == 1) {
245 sourceIndices.push_back(
indices[group[0]]);
249 llvm::map_to_vector(group, [&](
int64_t d) {
return destShape[d]; });
251 llvm::map_to_vector(group, [&](
int64_t d) {
return indices[d]; });
252 Value collapsedIndex = affine::AffineLinearizeIndexOp::create(
253 rewriter, loc, groupIndices, groupBasis, startsInbounds);
254 sourceIndices.push_back(collapsedIndex);
259 memref::CollapseShapeOp collapseShapeOp,
263 auto metadata = memref::ExtractStridedMetadataOp::create(
264 rewriter, loc, collapseShapeOp.getSrc());
266 for (
auto [
index, group] :
267 llvm::zip(
indices, collapseShapeOp.getReassociationIndices())) {
268 assert(!group.empty() &&
"association indices groups cannot be empty");
269 int64_t groupSize = group.size();
271 if (groupSize == 1) {
272 sourceIndices.push_back(
index);
277 llvm::map_to_vector(group, [&](
int64_t d) {
return sourceSizes[d]; });
278 auto delinearize = affine::AffineDelinearizeIndexOp::create(
279 rewriter, loc,
index, basis,
true);
280 llvm::append_range(sourceIndices,
delinearize.getResults());
282 if (collapseShapeOp.getReassociationIndices().empty()) {
285 cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
288 for (
int64_t i = 0; i < srcRank; i++) {
289 sourceIndices.push_back(
298 if (!subViewOp.hasZeroOffset() || !subViewOp.hasUnitStride())
301 MemRefType srcType = subViewOp.getSourceType();
302 MemRefType resType = subViewOp.getType();
303 unsigned srcRank = srcType.getRank();
304 unsigned resRank = resType.getRank();
305 if (srcRank <= resRank ||
indices.size() != resRank)
308 auto droppedDims = subViewOp.getDroppedDims();
309 if (droppedDims.none() || droppedDims.count() != srcRank - resRank)
312 auto mixedSizes = subViewOp.getMixedSizes();
313 if (mixedSizes.size() != srcRank)
316 unsigned resultDim = 0;
317 for (
unsigned sourceDim = 0; sourceDim < srcRank; ++sourceDim) {
318 if (droppedDims.test(sourceDim)) {
320 if (!sizeCst || *sizeCst != 1)
322 sourceIndices.push_back(
326 if (resultDim >=
indices.size())
328 sourceIndices.push_back(
indices[resultDim++]);
330 if (resultDim !=
indices.size())
static int64_t product(ArrayRef< int64_t > vals)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={}, LinearizedDivKind sizeDivKind=LinearizedDivKind::Floor)
static bool resultIsNotRead(Operation *op, std::vector< Operation * > &uses)
Returns true if all the uses of op are not read/load.
MemrefValue skipFullyAliasingOperations(MemrefValue source)
Walk up the source chain until an operation that changes/defines the view of memory is found (i....
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
Track temporary allocations that are never read from.
MemrefValue skipViewLikeOps(MemrefValue source)
Walk up the source chain until we find an operation that is not a view of the source memref (i....
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
LinearizedDivKind
Controls how the per-dimension contribution to linearizedSize is divided by dstBits / srcBits when sc...
LogicalResult resolveSourceIndicesRankReducingSubview(Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a rank-reducing full su...
static SmallVector< OpFoldResult > computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes, OpFoldResult unit)
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
SmallVector< OpFoldResult > computeSuffixProductIRBlock(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes)
Given a set of sizes, return the suffix product.
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N].
TypedValue< BaseMemRefType > MemrefValue
A value with a memref type.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...