#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

namespace mlir {
namespace memref {
/// Returns true if the memref type has a static shape and represents a
/// contiguous chunk of memory, i.e. it is laid out row-major with no padding
/// between rows.
bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
  if (!type.hasStaticShape())
    return false;

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(type.getStridesAndOffset(strides, offset)))
    return false;

  // Walk from the innermost dimension outward: each stride must equal the
  // running product of the inner dimension sizes.
  int64_t runningStride = 1;
  int64_t curDim = strides.size() - 1;
  // Find all inner dimensions with contiguous (suffix-product) strides.
  while (curDim >= 0 && strides[curDim] == runningStride) {
    runningStride *= type.getDimSize(curDim);
    --curDim;
  }

  // Any remaining outer dimensions must have size 1, since they then do not
  // contribute to the layout.
  while (curDim >= 0 && type.getDimSize(curDim) == 1)
    --curDim;

  return curDim < 0;
}
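// Illustrative examples (a sketch, not part of the original file):
//   memref<4x8xf32>                    -> true  (identity layout)
//   memref<4x8xf32, strided<[8, 1]>>   -> true  (explicit row-major strides)
//   memref<4x8xf32, strided<[16, 1]>>  -> false (rows padded to 16 elements)
//   memref<1x8xf32, strided<[64, 1]>>  -> true  (size-1 outer dim is ignored)
//   memref<?x8xf32>                    -> false (dynamic shape)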
/// For a memref with offset, sizes and strides, returns the offset, size, and
/// potentially the size padded at the front to use for the linearized memref,
/// along with the linearized index.
std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
    OpBuilder &builder, Location loc, int srcBits, int dstBits,
    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
  unsigned sourceRank = sizes.size();
  assert(sizes.size() == strides.size() &&
         "expected as many sizes as strides for a memref");
  SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
  if (indices.empty())
    indicesVec.resize(sourceRank, builder.getIndexAttr(0));
  assert(indicesVec.size() == strides.size() &&
         "expected as many indices as rank of memref");

  // Create the affine symbols and values for linearization.
  SmallVector<AffineExpr> symbols(2 * sourceRank);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  AffineExpr addMulMap = builder.getAffineConstantExpr(0);
  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);

  // Linearized index: sum over i of indices[i] * strides[i].
  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indicesVec[i];
    offsetValues[offsetIdx + 1] = strides[i];
  }

  // Adjust the linearized index by the scale factor (dstBits / srcBits).
  int64_t scaler = dstBits / srcBits;
  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap.floorDiv(scaler), offsetValues);

  // The linearized size is the maximum of strides[i] * sizes[i] over all
  // dimensions, scaled down to the destination bit width.
  size_t symbolIndex = 0;
  SmallVector<OpFoldResult> values;
  SmallVector<AffineExpr> productExpressions;
  for (unsigned i = 0; i < sourceRank; ++i) {
    AffineExpr strideExpr = symbols[symbolIndex++];
    values.push_back(strides[i]);
    AffineExpr sizeExpr = symbols[symbolIndex++];
    values.push_back(sizes[i]);
    productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
  }
  AffineMap maxMap = AffineMap::get(/*dimCount=*/0, symbolIndex,
                                    productExpressions, builder.getContext());
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);

  // Adjust the base offset by the same scale factor.
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {offset});

  // Remainder of the linearized index within one destination element.
  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap % scaler, offsetValues);

  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
          linearizedIndices};
}
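// Worked example (a sketch assuming sub-byte emulation of i4 data in i8
// storage, so scaler = dstBits / srcBits = 8 / 4 = 2): for a memref<3x4xi4>
// with offset 0, sizes (3, 4), strides (4, 1), and indices (%i, %j):
//   linearizedIndices = (%i * 4 + %j) floordiv 2
//   linearizedSize    = max((4 * 3) floordiv 2, (1 * 4) floordiv 2) = 6
//   adjustBaseOffset  = 0 floordiv 2 = 0
//   intraVectorOffset = (%i * 4 + %j) mod 2, i.e. which i4 nibble of the
//                       containing i8 element is addressed.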
LinearizedMemRefInfo
getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
                                 int dstBits, OpFoldResult offset,
                                 ArrayRef<OpFoldResult> sizes) {
  // Materialize the row-major strides as suffix products of the sizes.
  SmallVector<OpFoldResult> strides(sizes.size());
  if (!sizes.empty()) {
    strides.back() = builder.getIndexAttr(1);
    AffineExpr s0, s1;
    bindSymbols(builder.getContext(), s0, s1);
    for (int index = sizes.size() - 1; index > 0; --index) {
      strides[index - 1] = affine::makeComposedFoldedAffineApply(
          builder, loc, s0 * s1,
          ArrayRef<OpFoldResult>{strides[index], sizes[index]});
    }
  }

  LinearizedMemRefInfo linearizedMemRefInfo;
  std::tie(linearizedMemRefInfo, std::ignore) = getLinearizedMemRefOffsetAndSize(
      builder, loc, srcBits, dstBits, offset, sizes, strides);
  return linearizedMemRefInfo;
}
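// Usage sketch (hedged; `builder` and `loc` are assumed to be in scope). This
// overload derives the row-major strides itself, so it only needs the sizes;
// for a statically shaped memref<3x4xi4> emulated in i8 storage:
//   LinearizedMemRefInfo info = getLinearizedMemRefOffsetAndSize(
//       builder, loc, /*srcBits=*/4, /*dstBits=*/8,
//       /*offset=*/builder.getIndexAttr(0),
//       {builder.getIndexAttr(3), builder.getIndexAttr(4)});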
/// Returns true if all the uses of op are not read/load. There can be
/// SubViewOp users as long as all their users are also not reads. On success,
/// also appends the chain of non-reading users to `uses`; on failure, `uses`
/// is left unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
  std::vector<Operation *> opUses;
  for (OpOperand &use : op->getUses()) {
    Operation *useOp = use.getOwner();
    if (isa<memref::DeallocOp>(useOp) ||
        (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
         !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
        (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
      opUses.push_back(useOp);
      continue;
    }
    return false;
  }
  llvm::append_range(uses, opUses);
  return true;
}
/// Track temporary allocations that are never read from. If this is the case,
/// erase the allocations together with their stores.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
  std::vector<Operation *> opToErase;
  parentOp->walk([&](Operation *op) {
    std::vector<Operation *> candidates;
    if (isa<memref::AllocOp, memref::AllocaOp>(op) &&
        resultIsNotRead(op, candidates)) {
      llvm::append_range(opToErase, candidates);
      opToErase.push_back(op);
    }
  });
  for (Operation *op : opToErase)
    rewriter.eraseOp(op);
}
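// Illustrative IR sketch (names are hypothetical): the alloc below is only
// written to, so the walk erases both the store and the alloc:
//   %alloc = memref.alloc() : memref<4xf32>
//   memref.store %v, %alloc[%c0] : memref<4xf32>
// A memref.load or any other use with a read effect on %alloc would keep both.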
static SmallVector<OpFoldResult>
computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
                                ArrayRef<OpFoldResult> sizes,
                                OpFoldResult unit) {
  SmallVector<OpFoldResult> strides(sizes.size(), unit);
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  for (int64_t r = strides.size() - 1; r > 0; --r) {
    strides[r - 1] = affine::makeComposedFoldedAffineApply(
        builder, loc, s0 * s1, {strides[r], sizes[r]});
  }
  return strides;
}

/// Given a set of sizes, return the suffix product.
SmallVector<OpFoldResult>
computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
                            ArrayRef<OpFoldResult> sizes) {
  OpFoldResult unit = builder.getIndexAttr(1);
  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
}
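// Worked example (a sketch): for sizes (2, 3, 4) the suffix product yields
// the row-major strides (3 * 4, 4, 1) = (12, 4, 1). With static sizes the
// affine.apply ops fold to constants via makeComposedFoldedAffineApply, so no
// IR is emitted for the fully static case.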
/// Walk up the source chain until an operation that changes/defines the view
/// of memory is found (i.e. skip operations that alias the entire view).
MemrefValue skipFullyAliasingOperations(MemrefValue source) {
  while (auto op = source.getDefiningOp()) {
    if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
        subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
      // A subview with zero offset and unit strides aliases the entire
      // source memref.
      source = cast<MemrefValue>(subViewOp.getSource());
    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
      // A cast only changes the type, not the underlying memory.
      source = castOp.getSource();
    } else {
      return source;
    }
  }
  return source;
}
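// Illustrative IR sketch (types abbreviated, names hypothetical):
//   %a = memref.alloc() : memref<8x8xf32>
//   %b = memref.subview %a[0, 0] [8, 8] [1, 1] : memref<8x8xf32> to ...
//   %c = memref.cast %b : ... to ...
// skipFullyAliasingOperations(%c) walks through the cast and the zero-offset,
// unit-stride subview and returns %a.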
/// Walk up the source chain until we find an operation that is not a view of
/// the source memref (i.e. skip operations that alias part or all of the
/// source memref).
MemrefValue skipViewLikeOps(MemrefValue source) {
  while (auto op = source.getDefiningOp()) {
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
      if (source == viewLike.getViewDest()) {
        source = cast<MemrefValue>(viewLike.getViewSource());
        continue;
      }
    }
    return source;
  }
  return source;
}
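// Illustrative contrast with skipFullyAliasingOperations (a sketch): a
// subview with a nonzero offset is still a view, so for
//   %a = memref.alloc() : memref<8x8xf32>
//   %b = memref.subview %a[2, 2] [4, 4] [1, 1] : memref<8x8xf32> to ...
// skipViewLikeOps(%b) returns %a, while skipFullyAliasingOperations(%b) stops
// at %b because the subview does not alias all of %a.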
/// Given the `indices` of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape.
LogicalResult resolveSourceIndicesExpandShape(
    Location loc, PatternRewriter &rewriter,
    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

  // Traverse all reassociation groups to determine the appropriate indices
  // corresponding to each one of them post op folding.
  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();
    if (groupSize == 1) {
      sourceIndices.push_back(indices[group[0]]);
      continue;
    }
    SmallVector<OpFoldResult> groupBasis =
        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
    SmallVector<Value> groupIndices =
        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
    Value collapsedIndex = affine::AffineLinearizeIndexOp::create(
        rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
    sourceIndices.push_back(collapsedIndex);
  }
  return success();
}
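// Illustrative sketch: for
//   %0 = memref.expand_shape %src [[0, 1], [2]]
//        : memref<12x5xf32> into memref<3x4x5xf32>
// a load %0[%i, %j, %k] resolves to the source indices
//   (affine.linearize_index [%i, %j] by (3, 4), %k), i.e. (%i * 4 + %j, %k);
// the singleton group [2] passes %k through unchanged.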
/// Given the `indices` of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape.
LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  // Note: collapse_shape requires a strided memref, so we can query the
  // source sizes from its strided metadata.
  auto metadata = memref::ExtractStridedMetadataOp::create(
      rewriter, loc, collapseShapeOp.getSrc());
  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
  for (auto [index, group] :
       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();

    if (groupSize == 1) {
      sourceIndices.push_back(index);
      continue;
    }

    SmallVector<OpFoldResult> basis =
        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
    auto delinearize = affine::AffineDelinearizeIndexOp::create(
        rewriter, loc, index, basis, /*hasOuterBound=*/true);
    llvm::append_range(sourceIndices, delinearize.getResults());
  }
  if (collapseShapeOp.getReassociationIndices().empty()) {
    // Rank-0 result: every source dimension indexes at constant 0.
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
    for (int64_t i = 0; i < srcRank; i++) {
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}
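// Illustrative sketch: for
//   %0 = memref.collapse_shape %src [[0, 1], [2]]
//        : memref<3x4x5xf32> into memref<12x5xf32>
// a load %0[%i, %k] resolves to the source indices (%i0, %i1, %k), where
//   %i0, %i1 = affine.delinearize_index %i into (3, 4)
// recovers the two collapsed dimensions from the source sizes; the singleton
// group [2] passes %k through unchanged.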
} // namespace memref
} // namespace mlir