18#include "llvm/ADT/STLExtras.h"
24 if (!type.hasStaticShape())
33 if (failed(type.getStridesAndOffset(strides, offset)))
41 while (curDim >= 0 && strides[curDim] == runningStride) {
42 runningStride *= type.getDimSize(curDim);
47 while (curDim >= 0 && type.getDimSize(curDim) == 1) {
60 unsigned sourceRank = sizes.size();
61 assert(sizes.size() == strides.size() &&
62 "expected as many sizes as strides for a memref");
66 assert(indicesVec.size() == strides.size() &&
67 "expected as many indices as rank of memref");
76 for (
unsigned i = 0; i < sourceRank; ++i) {
77 unsigned offsetIdx = 2 * i;
78 addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
79 offsetValues[offsetIdx] = indicesVec[i];
80 offsetValues[offsetIdx + 1] = strides[i];
83 int64_t scaler = dstBits / srcBits;
85 builder, loc, addMulMap.
floorDiv(scaler), offsetValues);
87 size_t symbolIndex = 0;
90 for (
unsigned i = 0; i < sourceRank; ++i) {
91 AffineExpr strideExpr = symbols[symbolIndex++];
92 values.push_back(strides[i]);
94 values.push_back(sizes[i]);
102 0, symbolIndex, productExpressions,
111 builder, loc, s0.
floorDiv(scaler), {offset});
114 builder, loc, addMulMap % scaler, offsetValues);
116 return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
126 if (!sizes.empty()) {
133 builder, loc, s0 * s1,
139 std::tie(linearizedMemRefInfo, std::ignore) =
143 return linearizedMemRefInfo;
151 std::vector<Operation *> opUses;
157 if (isa<memref::DeallocOp>(useOp) ||
161 opUses.push_back(useOp);
166 llvm::append_range(uses, opUses);
171 std::vector<Operation *> opToErase;
173 std::vector<Operation *> candidates;
174 if (isa<memref::AllocOp, memref::AllocaOp>(op) &&
176 llvm::append_range(opToErase, candidates);
177 opToErase.push_back(op);
193 for (
int64_t r =
static_cast<int64_t>(strides.size()) - 1; r > 0; --r) {
195 builder, loc, s0 * s1, {strides[r], sizes[r]});
208 while (
auto *op = source.getDefiningOp()) {
209 if (
auto subViewOp = dyn_cast<memref::SubViewOp>(op);
210 subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
213 source = cast<MemrefValue>(subViewOp.getSource());
214 }
else if (
auto castOp = dyn_cast<memref::CastOp>(op)) {
216 source = castOp.getSource();
225 while (
auto *op = source.getDefiningOp()) {
226 if (
auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
227 if (source == viewLike.getViewDest()) {
228 source = cast<MemrefValue>(viewLike.getViewSource());
238 memref::ExpandShapeOp expandShapeOp,
241 bool startsInbounds) {
247 assert(!group.empty() &&
"association indices groups cannot be empty");
248 int64_t groupSize = group.size();
249 if (groupSize == 1) {
250 sourceIndices.push_back(
indices[group[0]]);
254 llvm::map_to_vector(group, [&](
int64_t d) {
return destShape[d]; });
256 llvm::map_to_vector(group, [&](
int64_t d) {
return indices[d]; });
257 Value collapsedIndex = affine::AffineLinearizeIndexOp::create(
258 rewriter, loc, groupIndices, groupBasis, startsInbounds);
259 sourceIndices.push_back(collapsedIndex);
264 memref::CollapseShapeOp collapseShapeOp,
267 bool startsInbounds) {
269 auto metadata = memref::ExtractStridedMetadataOp::create(
270 rewriter, loc, collapseShapeOp.getSrc());
272 for (
auto [
index, group] :
273 llvm::zip(
indices, collapseShapeOp.getReassociationIndices())) {
274 assert(!group.empty() &&
"association indices groups cannot be empty");
275 int64_t groupSize = group.size();
277 if (groupSize == 1) {
278 sourceIndices.push_back(
index);
288 trimmedGroup, [&](
int64_t d) {
return sourceSizes[d]; });
289 auto delinearize = affine::AffineDelinearizeIndexOp::create(
290 rewriter, loc,
index, basis, startsInbounds);
291 llvm::append_range(sourceIndices,
delinearize.getResults());
293 if (collapseShapeOp.getReassociationIndices().empty()) {
296 cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
299 for (
int64_t i = 0; i < srcRank; i++) {
300 sourceIndices.push_back(
309 if (!subViewOp.hasZeroOffset() || !subViewOp.hasUnitStride())
312 MemRefType srcType = subViewOp.getSourceType();
313 MemRefType resType = subViewOp.getType();
314 unsigned srcRank = srcType.getRank();
315 unsigned resRank = resType.getRank();
316 if (srcRank <= resRank ||
indices.size() != resRank)
319 auto droppedDims = subViewOp.getDroppedDims();
320 if (droppedDims.none() || droppedDims.count() != srcRank - resRank)
323 auto mixedSizes = subViewOp.getMixedSizes();
324 if (mixedSizes.size() != srcRank)
327 unsigned resultDim = 0;
328 for (
unsigned sourceDim = 0; sourceDim < srcRank; ++sourceDim) {
329 if (droppedDims.test(sourceDim)) {
331 if (!sizeCst || *sizeCst != 1)
333 sourceIndices.push_back(
337 if (resultDim >=
indices.size())
339 sourceIndices.push_back(
indices[resultDim++]);
341 if (resultDim !=
indices.size())
static int64_t product(ArrayRef< int64_t > vals)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This class provides the API for ops that are known to be terminators.
Operation is the basic unit of execution within MLIR.
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
unsigned getNumRegions()
Returns the number of regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={}, LinearizedDivKind sizeDivKind=LinearizedDivKind::Floor)
static bool resultIsNotRead(Operation *op, std::vector< Operation * > &uses)
Returns true if all the uses of op are not read/load.
MemrefValue skipFullyAliasingOperations(MemrefValue source)
Walk up the source chain until an operation that changes/defines the view of memory is found (i....
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
Track temporary allocations that are never read from.
MemrefValue skipViewLikeOps(MemrefValue source)
Walk up the source chain until we find an operation that is not a view of the source memref (i....
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
LinearizedDivKind
Controls how the per-dimension contribution to linearizedSize is divided by dstBits / srcBits when sc...
LogicalResult resolveSourceIndicesRankReducingSubview(Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a rank-reducing full su...
static SmallVector< OpFoldResult > computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes, OpFoldResult unit)
SmallVector< OpFoldResult > computeSuffixProductIRBlock(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes)
Given a set of sizes, return the suffix product.
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of an expand_shape op,...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
TypedValue< BaseMemRefType > MemrefValue
A value with a memref type.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
bool hasEffect(Operation *op)
Returns "true" if op has an effect of type EffectTy.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...