19 #include "llvm/ADT/STLExtras.h"
25 if (!type.hasStaticShape())
35 int64_t runningStride = 1;
36 int64_t curDim = strides.size() - 1;
38 while (curDim >= 0 && strides[curDim] == runningStride) {
39 runningStride *= type.getDimSize(curDim);
44 while (curDim >= 0 && type.getDimSize(curDim) == 1) {
56 unsigned sourceRank = sizes.size();
57 assert(sizes.size() == strides.size() &&
58 "expected as many sizes as strides for a memref");
62 assert(indicesVec.size() == strides.size() &&
63 "expected as many indices as rank of memref");
73 for (
unsigned i = 0; i < sourceRank; ++i) {
74 unsigned offsetIdx = 2 * i;
75 addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
76 offsetValues[offsetIdx] = indicesVec[i];
77 offsetValues[offsetIdx + 1] = strides[i];
79 mulMap = mulMap * symbols[i];
83 int64_t scaler = dstBits / srcBits;
84 addMulMap = addMulMap.
floorDiv(scaler);
88 builder, loc, addMulMap, offsetValues);
96 builder, loc, s0.
floorDiv(scaler), {offset});
98 return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
106 if (!sizes.empty()) {
110 for (
int index = sizes.size() - 1; index > 0; --index) {
112 builder, loc, s0 * s1,
118 std::tie(linearizedMemRefInfo, std::ignore) =
121 return linearizedMemRefInfo;
129 std::vector<Operation *> opUses;
132 if (isa<memref::DeallocOp>(useOp) ||
134 !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
136 opUses.push_back(useOp);
141 uses.insert(uses.end(), opUses.begin(), opUses.end());
146 std::vector<Operation *> opToErase;
147 parentOp->
walk([&](memref::AllocOp op) {
148 std::vector<Operation *> candidates;
150 opToErase.insert(opToErase.end(), candidates.begin(), candidates.end());
151 opToErase.push_back(op.getOperation());
166 for (int64_t r = strides.size() - 1; r > 0; --r) {
168 builder, loc, s0 * s1, {strides[r], sizes[r]});
181 while (
auto op = source.getDefiningOp()) {
182 if (
auto subViewOp = dyn_cast<memref::SubViewOp>(op);
183 subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
186 source = cast<MemrefValue>(subViewOp.getSource());
187 }
else if (
auto castOp = dyn_cast<memref::CastOp>(op)) {
189 source = castOp.getSource();
198 while (
auto op = source.getDefiningOp()) {
199 if (
auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
200 source = cast<MemrefValue>(viewLike.getViewSource());
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a pointer to a storage object.
This class helps build Operations.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), in post-order.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
static bool resultIsNotRead(Operation *op, std::vector< Operation * > &uses)
Returns true if all the uses of op are not read/load.
MemrefValue skipFullyAliasingOperations(MemrefValue source)
Walk up the source chain until an operation that changes/defines the view of memory is found (i.e. skip operations that alias the entire view).
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
MemrefValue skipViewLikeOps(MemrefValue source)
Walk up the source chain until we find an operation that is not a view of the source memref (i.e. implements ViewLikeOpInterface).
bool isStaticShapeAndContiguousRowMajor(MemRefType type)
Returns true, if the memref type has static shapes and represents a contiguous chunk of memory.
static SmallVector< OpFoldResult > computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes, OpFoldResult unit)
SmallVector< OpFoldResult > computeSuffixProductIRBlock(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes)
Given a set of sizes, return the suffix product.
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
TypedValue< BaseMemRefType > MemrefValue
A value with a memref type.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N].
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
For a memref with offset, sizes and strides, returns the offset and size to use for the linearized memref.